{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Chapter 12 – Custom Models and Training with TensorFlow**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "_This notebook contains all the sample code in chapter 12._"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<table align=\"left\">\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/ageron/handson-ml2/blob/master/12_custom_models_and_training_with_tensorflow.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, let's import a few common modules, ensure Matplotlib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20 and TensorFlow ≥2.0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Python ≥3.5 is required\n",
    "import re\n",
    "import sys\n",
    "assert sys.version_info >= (3, 5)\n",
    "\n",
    "def _version_tuple(version_string):\n",
    "    \"\"\"Return the leading (major, minor) numbers of a version string as a tuple of ints.\n",
    "\n",
    "    Comparing version strings as text is wrong: \"0.9\" >= \"0.20\" is True and\n",
    "    \"10.0\" >= \"2.0\" is False. Anything after the minor number (patch level,\n",
    "    \"rc1\", \".dev0\", ...) is ignored.\n",
    "    \"\"\"\n",
    "    numbers = [int(digits) for digits in re.findall(\"[0-9]+\", version_string)[:2]]\n",
    "    # Pad so that a bare \"2\" compares like \"2.0\"\n",
    "    return tuple(numbers + [0] * (2 - len(numbers)))\n",
    "\n",
    "# Scikit-Learn ≥0.20 is required\n",
    "import sklearn\n",
    "assert _version_tuple(sklearn.__version__) >= (0, 20)\n",
    "\n",
    "try:\n",
    "    # %tensorflow_version only exists in Colab.\n",
    "    %tensorflow_version 2.x\n",
    "except Exception:\n",
    "    pass\n",
    "\n",
    "# TensorFlow ≥2.0 is required\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "assert _version_tuple(tf.__version__) >= (2, 0)\n",
    "\n",
    "# Common imports\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "# to make this notebook's output stable across runs\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)\n",
    "\n",
    "# To plot pretty figures\n",
    "%matplotlib inline\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "mpl.rc('axes', labelsize=14)\n",
    "mpl.rc('xtick', labelsize=12)\n",
    "mpl.rc('ytick', labelsize=12)\n",
    "\n",
    "# Where to save the figures\n",
    "PROJECT_ROOT_DIR = \".\"\n",
    "CHAPTER_ID = \"deep\"\n",
    "IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
    "os.makedirs(IMAGES_PATH, exist_ok=True)\n",
    "\n",
    "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
    "    \"\"\"Save the current Matplotlib figure as IMAGES_PATH/<fig_id>.<fig_extension>.\n",
    "\n",
    "    fig_id: file name without extension (a string).\n",
    "    tight_layout: when true, call plt.tight_layout() before saving.\n",
    "    fig_extension: file extension, also passed to plt.savefig as the format.\n",
    "    resolution: passed to plt.savefig as the dpi.\n",
    "    \"\"\"\n",
    "    path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n",
    "    print(\"Saving figure\", fig_id)\n",
    "    if tight_layout:\n",
    "        plt.tight_layout()\n",
    "    plt.savefig(path, format=fig_extension, dpi=resolution)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tensors and operations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 3), dtype=float32, numpy=\n",
       "array([[1., 2., 3.],\n",
       "       [4., 5., 6.]], dtype=float32)>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.constant([[1., 2., 3.], [4., 5., 6.]]) # matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=int32, numpy=42>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.constant(42) # scalar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 3), dtype=float32, numpy=\n",
       "array([[1., 2., 3.],\n",
       "       [4., 5., 6.]], dtype=float32)>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = tf.constant([[1., 2., 3.], [4., 5., 6.]])\n",
    "t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([2, 3])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tf.float32"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t.dtype"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Indexing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       "array([[2., 3.],\n",
       "       [5., 6.]], dtype=float32)>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t[:, 1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 1), dtype=float32, numpy=\n",
       "array([[2.],\n",
       "       [5.]], dtype=float32)>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t[..., 1, tf.newaxis]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 3), dtype=float32, numpy=\n",
       "array([[11., 12., 13.],\n",
       "       [14., 15., 16.]], dtype=float32)>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t + 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 3), dtype=float32, numpy=\n",
       "array([[ 1.,  4.,  9.],\n",
       "       [16., 25., 36.]], dtype=float32)>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.square(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       "array([[14., 32.],\n",
       "       [32., 77.]], dtype=float32)>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t @ tf.transpose(t)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using `keras.backend`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(3, 2), dtype=float32, numpy=\n",
       "array([[11., 26.],\n",
       "       [14., 35.],\n",
       "       [19., 46.]], dtype=float32)>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tensorflow import keras\n",
    "K = keras.backend\n",
    "K.square(K.transpose(t)) + 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### From/To NumPy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(3,), dtype=float64, numpy=array([2., 4., 5.])>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = np.array([2., 4., 5.])\n",
    "tf.constant(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1., 2., 3.],\n",
       "       [4., 5., 6.]], dtype=float32)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1., 2., 3.],\n",
       "       [4., 5., 6.]], dtype=float32)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(3,), dtype=float64, numpy=array([ 4., 16., 25.])>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.square(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 1.,  4.,  9.],\n",
       "       [16., 25., 36.]], dtype=float32)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.square(t)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Conflicting Types"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cannot compute AddV2 as input #1(zero-based) was expected to be a float tensor but is a int32 tensor [Op:AddV2] name: add/\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    tf.constant(2.0) + tf.constant(40)\n",
    "except tf.errors.InvalidArgumentError as ex:\n",
    "    print(ex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cannot compute AddV2 as input #1(zero-based) was expected to be a float tensor but is a double tensor [Op:AddV2] name: add/\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    tf.constant(2.0) + tf.constant(40., dtype=tf.float64)\n",
    "except tf.errors.InvalidArgumentError as ex:\n",
    "    print(ex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float32, numpy=42.0>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t2 = tf.constant(40., dtype=tf.float64)\n",
    "tf.constant(2.0) + tf.cast(t2, tf.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Strings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=string, numpy=b'hello world'>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.constant(b\"hello world\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=string, numpy=b'caf\\xc3\\xa9'>"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.constant(\"café\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(4,), dtype=int32, numpy=array([ 99,  97, 102, 233], dtype=int32)>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "u = tf.constant([ord(c) for c in \"café\"])\n",
    "u"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=int32, numpy=4>"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b = tf.strings.unicode_encode(u, \"UTF-8\")\n",
    "tf.strings.length(b, unit=\"UTF8_CHAR\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(4,), dtype=int32, numpy=array([ 99,  97, 102, 233], dtype=int32)>"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.strings.unicode_decode(b, \"UTF-8\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### String arrays"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = tf.constant([\"Café\", \"Coffee\", \"caffè\", \"咖啡\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(4,), dtype=int32, numpy=array([4, 6, 5, 2], dtype=int32)>"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.strings.length(p, unit=\"UTF8_CHAR\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.RaggedTensor [[67, 97, 102, 233], [67, 111, 102, 102, 101, 101], [99, 97, 102, 102, 232], [21654, 21857]]>"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r = tf.strings.unicode_decode(p, \"UTF8\")\n",
    "r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<tf.RaggedTensor [[67, 97, 102, 233], [67, 111, 102, 102, 101, 101], [99, 97, 102, 102, 232], [21654, 21857]]>\n"
     ]
    }
   ],
   "source": [
    "print(r)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ragged tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor([ 67 111 102 102 101 101], shape=(6,), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "print(r[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<tf.RaggedTensor [[67, 111, 102, 102, 101, 101], [99, 97, 102, 102, 232]]>\n"
     ]
    }
   ],
   "source": [
    "print(r[1:3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<tf.RaggedTensor [[67, 97, 102, 233], [67, 111, 102, 102, 101, 101], [99, 97, 102, 102, 232], [21654, 21857], [65, 66], [], [67]]>\n"
     ]
    }
   ],
   "source": [
    "r2 = tf.ragged.constant([[65, 66], [], [67]])\n",
    "print(tf.concat([r, r2], axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<tf.RaggedTensor [[67, 97, 102, 233, 68, 69, 70], [67, 111, 102, 102, 101, 101, 71], [99, 97, 102, 102, 232], [21654, 21857, 72, 73]]>\n"
     ]
    }
   ],
   "source": [
    "r3 = tf.ragged.constant([[68, 69, 70], [71], [], [72, 73]])\n",
    "print(tf.concat([r, r3], axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(4,), dtype=string, numpy=array([b'DEF', b'G', b'', b'HI'], dtype=object)>"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.strings.unicode_encode(r3, \"UTF-8\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(4, 6), dtype=int32, numpy=\n",
       "array([[   67,    97,   102,   233,     0,     0],\n",
       "       [   67,   111,   102,   102,   101,   101],\n",
       "       [   99,    97,   102,   102,   232,     0],\n",
       "       [21654, 21857,     0,     0,     0,     0]], dtype=int32)>"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r.to_tensor()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sparse tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "s = tf.SparseTensor(indices=[[0, 1], [1, 0], [2, 3]],\n",
    "                    values=[1., 2., 3.],\n",
    "                    dense_shape=[3, 4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SparseTensor(indices=tf.Tensor(\n",
      "[[0 1]\n",
      " [1 0]\n",
      " [2 3]], shape=(3, 2), dtype=int64), values=tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32), dense_shape=tf.Tensor([3 4], shape=(2,), dtype=int64))\n"
     ]
    }
   ],
   "source": [
    "print(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(3, 4), dtype=float32, numpy=\n",
       "array([[0., 1., 0., 0.],\n",
       "       [2., 0., 0., 0.],\n",
       "       [0., 0., 0., 3.]], dtype=float32)>"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.sparse.to_dense(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2 = s * 2.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "unsupported operand type(s) for +: 'SparseTensor' and 'float'\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    s3 = s + 1.\n",
    "except TypeError as ex:\n",
    "    print(ex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(3, 2), dtype=float32, numpy=\n",
       "array([[ 30.,  40.],\n",
       "       [ 20.,  40.],\n",
       "       [210., 240.]], dtype=float32)>"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s4 = tf.constant([[10., 20.], [30., 40.], [50., 60.], [70., 80.]])\n",
    "tf.sparse.sparse_dense_matmul(s, s4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SparseTensor(indices=tf.Tensor(\n",
      "[[0 2]\n",
      " [0 1]], shape=(2, 2), dtype=int64), values=tf.Tensor([1. 2.], shape=(2,), dtype=float32), dense_shape=tf.Tensor([3 4], shape=(2,), dtype=int64))\n"
     ]
    }
   ],
   "source": [
    "s5 = tf.SparseTensor(indices=[[0, 2], [0, 1]],\n",
    "                     values=[1., 2.],\n",
    "                     dense_shape=[3, 4])\n",
    "print(s5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "indices[1] = [0,1] is out of order. Many sparse ops require sorted indices.\n",
      "    Use `tf.sparse.reorder` to create a correctly ordered copy.\n",
      "\n",
      " [Op:SparseToDense]\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    tf.sparse.to_dense(s5)\n",
    "except tf.errors.InvalidArgumentError as ex:\n",
    "    print(ex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(3, 4), dtype=float32, numpy=\n",
       "array([[0., 2., 1., 0.],\n",
       "       [0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0.]], dtype=float32)>"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s6 = tf.sparse.reorder(s5)\n",
    "tf.sparse.to_dense(s6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 6), dtype=int32, numpy=\n",
       "array([[ 2,  3,  4,  5,  6,  7],\n",
       "       [ 0,  7,  9, 10,  0,  0]], dtype=int32)>"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "set1 = tf.constant([[2, 3, 5, 7], [7, 9, 0, 0]])\n",
    "set2 = tf.constant([[4, 5, 6], [9, 10, 0]])\n",
    "tf.sparse.to_dense(tf.sets.union(set1, set2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 3), dtype=int32, numpy=\n",
       "array([[2, 3, 7],\n",
       "       [7, 0, 0]], dtype=int32)>"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.sparse.to_dense(tf.sets.difference(set1, set2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=int32, numpy=\n",
       "array([[5, 0],\n",
       "       [0, 9]], dtype=int32)>"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.sparse.to_dense(tf.sets.intersection(set1, set2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "v = tf.Variable([[1., 2., 3.], [4., 5., 6.]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Variable 'UnreadVariable' shape=(2, 3) dtype=float32, numpy=\n",
       "array([[ 2.,  4.,  6.],\n",
       "       [ 8., 10., 12.]], dtype=float32)>"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.assign(2 * v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Variable 'UnreadVariable' shape=(2, 3) dtype=float32, numpy=\n",
       "array([[ 2., 42.,  6.],\n",
       "       [ 8., 10., 12.]], dtype=float32)>"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v[0, 1].assign(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Variable 'UnreadVariable' shape=(2, 3) dtype=float32, numpy=\n",
       "array([[ 2., 42.,  0.],\n",
       "       [ 8., 10.,  1.]], dtype=float32)>"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v[:, 2].assign([0., 1.])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'ResourceVariable' object does not support item assignment\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    v[1] = [7., 8., 9.]\n",
    "except TypeError as ex:\n",
    "    print(ex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Variable 'UnreadVariable' shape=(2, 3) dtype=float32, numpy=\n",
       "array([[100.,  42.,   0.],\n",
       "       [  8.,  10., 200.]], dtype=float32)>"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.scatter_nd_update(indices=[[0, 0], [1, 2]],\n",
    "                    updates=[100., 200.])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Variable 'UnreadVariable' shape=(2, 3) dtype=float32, numpy=\n",
       "array([[4., 5., 6.],\n",
       "       [1., 2., 3.]], dtype=float32)>"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sparse_delta = tf.IndexedSlices(values=[[1., 2., 3.], [4., 5., 6.]],\n",
    "                                indices=[1, 0])\n",
    "v.scatter_update(sparse_delta)"
   ]
  },
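  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tensor Arrays\n",
    "\n",
    "**Note**: a `tf.TensorArray` is created with `clear_after_read=True` by default, so reading an element (as `array.read(1)` does below) clears it. This is why `array.stack()` then shows zeros at index 1, and why the mean and variance are computed on the zeroed-out row."
   ]
  },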
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "array = tf.TensorArray(dtype=tf.float32, size=3)\n",
    "array = array.write(0, tf.constant([1., 2.]))\n",
    "array = array.write(1, tf.constant([3., 10.]))\n",
    "array = array.write(2, tf.constant([5., 7.]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2,), dtype=float32, numpy=array([ 3., 10.], dtype=float32)>"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "array.read(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(3, 2), dtype=float32, numpy=\n",
       "array([[1., 2.],\n",
       "       [0., 0.],\n",
       "       [5., 7.]], dtype=float32)>"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "array.stack()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2,), dtype=float32, numpy=array([2., 3.], dtype=float32)>"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mean, variance = tf.nn.moments(array.stack(), axes=0)\n",
    "mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2,), dtype=float32, numpy=array([4.6666665, 8.666667 ], dtype=float32)>"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "variance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Custom loss function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start by loading and preparing the California housing dataset. We first load it, then split it into a training set, a validation set and a test set, and finally we scale it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import fetch_california_housing\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "housing = fetch_california_housing()\n",
    "X_train_full, X_test, y_train_full, y_test = train_test_split(\n",
    "    housing.data, housing.target.reshape(-1, 1), random_state=42)\n",
    "X_train, X_valid, y_train, y_valid = train_test_split(\n",
    "    X_train_full, y_train_full, random_state=42)\n",
    "\n",
    "scaler = StandardScaler()\n",
    "X_train_scaled = scaler.fit_transform(X_train)\n",
    "X_valid_scaled = scaler.transform(X_valid)\n",
    "X_test_scaled = scaler.transform(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "def huber_fn(y_true, y_pred):\n",
    "    \"\"\"Element-wise Huber loss with the threshold fixed at 1.\n",
    "\n",
    "    Where |y_true - y_pred| < 1 the loss is error**2 / 2; elsewhere it is\n",
    "    |error| - 0.5. No reduction is applied here.\n",
    "    \"\"\"\n",
    "    residual = y_true - y_pred\n",
    "    abs_residual = tf.abs(residual)\n",
    "    return tf.where(abs_residual < 1,\n",
    "                    tf.square(residual) / 2,\n",
    "                    abs_residual - 0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd8AAAEDCAYAAAB0/A4MAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deZzN9ffA8dd7xsQMY8seIhmyS5GQJaKS9r1QiVS/LGlR2lGEUIg2pCyVULJkjbKU4kuK7PtuzIxZzPL+/XFmGMuYOzP33s9dzvPxuI+ZufOZ+zmfuTP33M/nfd7nbay1KKWUUsp7QpwOQCmllAo2mnyVUkopL9Pkq5RSSnmZJl+llFLKyzT5KqWUUl6myVcppZTyMk2+SvkQY0wLY4w1xpTw0v46G2PivLEvpdQZmnyVyiNjzHhjzI8XuP+a9ERayftRKaV8mSZfpYKAMeYSp2NQSp2hyVcpL7nQJWVjTKX0+645Z/PrjDFrjTGJxpg1xpgG5zzW9caYpcaYeGPMXmPMGGNM4UzfX5J+3xBjzGHg1xzE2c0Ys8UYcyr945MX+P7m9NgOG2PmGWPypX+vtjFmoTEmxhgTa4xZZ4xpmZPfk1LBQJOvUr5pCPAScA2wDZhtjIkASXDAfGAWUBe4C6gHfH7OYzwCGKAZ0NGVnRpj7gQ+AoYDtYARwGhjzG3p378GGAW8BVQDWgNzMz3E18B+oCFQH3gTSHT5qJUKEvmcDkCpANHuAoVLeXlz+461dh6AMeYxYA/wEPAp8AIw1Vo7NGNjY0x34C9jTClr7aH0u7dba5/P4X77AF9aaz9K/3pz+ln3S8APQEXgJDDLWhsL7ATWZfr5y4Eh1tp/07/eksP9KxUU9MxXKff4BTn7zHx7KA+PtyLjE2ttHLAeqJF+VwPgEWNMXMaNM5eVq2R6jDW52O9VnH+Jenmmff+MJNztxpivjDGdjDGRmbYdBnxqjFlkjHnVGFM9FzEoFfA0+SrlHvHW2i2Zb8jZamZp6R9NpvvCcrGvEOQMOHOirwtUBdZm2u5kLh4b4EJLnVmA9LPdq4H7gF1AX+BfY0y59O+/iSTqGcD1wP+MMY/nMg6lApYmX6W853D6x7KZ7quXxbbXZXxijCmIjL/+k37Xn0DNc5N9+i0hjzH+AzQ9576mwMaML6y1KdbaRdbavkAdoCDQPtP3/7PWjrTW3gp8BnTJY0xKBRwd81XKe7YAu4E3jTEvA5WAflls2y+9Snkf8DpwCilmAhgErDTGfAyMBWKB6sBt1tpueYzxfeAbY8wapKirHfAwUtSFMaY9cmn7F+AY0BKIBP4xxoQjhWLfADuA0kjiXpXHmJQKOJp8lfISa22yMeYBYDRSpLQWeAU4r0EH8DIwFKko/htob609mf44/zPG3AD0B5YCoUhF9PduiHGGMeb/kMKr4cj47tPW2h/SN4kG7kDeEEQAW4Eu1tpl6XOJiwETgDLA0fRj65PXuJQKNMbaCw3vKKWUUspTdMxXKaWU8rIcJV9jTNX0rjaTPBWQUkopFehyeuY7CvjdE4EopZRSwcLl5JteKBINLPRcOEoppVTgcyn5pjdsfxvIaas6pZRSSp3D1alG7wCfWWt3G2Oy3MgY0xXoClCgQIEGFStWzHuEPiotLY2QkKzfu1gLR4/mp0SJJC9G5R7ZHZu/C+Tj2717N9Zagvl/z9/54/ElJoaSP38axlx89ow/HltObN68+Yi1tqQr22abfI0x9ZCVS+pnt621dhwwDqBatWp206ZNrsTgl5YsWUKLFi2y3S4+HiIiPB+PO7l6bP4qkI+vRYsWREdHs3bt2uw39lOB/PyB/x3fypVQsyZERma/rb8dW04ZY3a6uq0rb0FaIJ14dhljDiAT5u82xvyZq+iCyPHjUKcOpKQ4HYlSSnnGhAmwb5/TUfgfVy47jwOmZPq6D5KMu3sioEBSrBhs2AD5
tI+YUipAjRnjdAT+KdszX2ttvLX2QMYNiAMSrbWHs/tZBcnJ8MorMgaslFKB5Pbb4e+/nY7CP+X4nCx9yTDlokKFoEIFufQclpvF45RSykd98AEEcG2fRwVu2ZmPMAa6d4cDB5yORCml3Oe776BoUR1Wyy1Nvl6QkAC33ioflVIqEKxapcNpeaHvWbwgPBzWrZOzYKWU8nfJyTB4sNNR+Dc98/WS1FS4/36Z96uUUv7KWqhfH/budToS/6Znvl6SLx906aLjI0op/2YM/PYbFC7sdCT+Tc98vah1a1i9WsdJlFL+69139TXMHTT5etkHH8BhnSGtlPJDqalwySVQsKDTkfg/vQjqRcZIeb5SSvmjI0fgeV3bzi30zNfLrIU2bWDXLqcjUUop1yUlQfPmEBfndCSBQc98vcwYGDUKypd3OhKllHJd/vywcSME8IqAXqW/RgdERcE33+i0I6WUfzh1Cjp1kvm9yj00+Tpk40Y4dMjpKJRSyjV33SVnv8o99LKzQ956SxZbsFY7XymlfNv69dChg9NRBBY983XQrbfCmjVOR6GUUlk7dkyWRU1LczqSwKJnvg6aOlVWBVFKKV9VvDjMm+d0FIFHz3wdVLQofPopbNrkdCRKKXW+3bvh5pu1o5UneCz5RkfryvGuKFZMS/eVUr6pXDkYMkTrUlzxxRc5295jL/uHDhWgRw9pR6aydvfdUKYMxMQ4HYlSSp1x4gTMmgU1azodiW9LS4O+feHxx3P2cx495xo5Em6/HWJjPbkX/9evH8yd63QUSil1xsGDMiVSZS0+XpaKfe89CA3N2c96LPlWqBBP8eIwezY0ayZjB+rChg+H++5zOgqllBKpqVClCrz6qtOR+K4DB6BlS/j2W1le8aefcvbzHku+4eGprFol3ZzWrYNGjXRaTVaMgYkT4auvnI5EKaVg/nzo2NHpKHzXhg2S01avhkqVZH3jm27K2WN49LLzlVfCihXQogXs3y9nwDNmeHKP/uvaa+H6652OQimlpMJ5zBino/BNc+fKa/WuXXDddbByZe7GxT1eZ5sxR6xzZ0hIkBZlQ4Zo6fq5rrrqTONypZRyyqpV0oOgcGGnI/E9o0dLc6TYWBkqXLQISpfO3WN5ZZLLJZfA55/DwIGSdF94Abp10ybd5/r1V1iyxOkolFLBLCJCpkCqM1JToVcveOYZqW7u1w8mT4bw8Nw/ptc6XBkj5dhXXiljCZ98Atu2yWC1dnkS994rH7Xfs1LKCUeOSKFV7dpOR+I74uLgwQfhxx8hLExyV6dOeX9cr7d3uPdeObsrVQoWLoTGjSUJK/HTT9C9u9NRKKWC0TffwAcfOB2F79izR2qVfvxRrgb8/LN7Ei841Nu5USMZV2jfHv7+W76eOVMLjkCe6OuuczoKpVQw6t5d63Ey/Pkn3HYb7NsHVatKAo6Kct/jO9bYsFIlGeNs21YudbRqJdfQg11kpFzm+O47pyNRSgWTkSPlJEiHvOT30KyZJN4bbpBZO+5MvODwwgpFisi7ie7dISkJHnoI3n5b33mlpsrlDqWU8pZ27aBePaejcJa1MGwY3HmndK/q2FHmPF96qfv35XhL/3z5YNQoGWcwBt54Qw44KcnpyJxTuTL06CG9VZVSytNWr5YxzcsvdzoS5yQny4ng889LEu7fH8aPlymgnuB48gVJuj17yql+wYIwaRK0bi2Xo4PV9u1yKT7YrwIopTxv3jz47z+no3DOiRNSgzR2rCTbKVOktaYnL8H7RPLNcNttsGwZXHYZLF8uhUfButZt5coyzqDjL0opT0pLg9deC96C1x075Njnz4eSJWHxYlkswdN8KvkC1K8vldD168PWrZKAFy92OipnnDoFTz6pyzIqpTznxhtl1kkwWrlSZtts3Ag1akjuadzYO/v2ueQLcua7bBl06ADR0dKwOqcLFQeCggWlCCItzelIlFKBaupUSTzBZto0WZXo0CFo00Zm31Su7L39+2TyBUk8
06fL4HdKiixU3LdvcCUiY6QX9po1OvarlHK/wYNljDOYhreslVbH998PiYnQtassfevtTos+m3xBFiceMgQ+/lg+f+89+YXFxzsdmXcNHixrRyqllLukpEgiKljQ6Ui859QpeOyxM8VUQ4dKfgkL834sPp18M3TrBnPmyCob334rlwqCJRkZI1cAypRxOhKlVCA5dAheekmmewaDY8dkCHPCBFk8Yvp06N3bubN+v0i+INfkV6yQzlirV8sg+fr1TkflPffdB3/84XQUSqlAcOyYrNmbkuJ0JN7x339SvLt0KZQtC7/8Anfc4WxMLiVfY8wkY8x+Y0yMMWazMaaLpwO7kBo1pDrtuutkIeMmTWRh42Dw/vtw9dVOR6GUCgTFi8PatcFx1vvLL5Iz/vsP6taVk7cGDZyOyvUz33eBStbawkAHoL8xxpHwS5eWBYzvv18WNL71VlngONBVqiQrHgXrvGellHts2wZdugRHkdXEidKw6dgxaaKxfDmUL+90VMKl5Gut/dtam9Hw0abfqngsqmyEh8PXX8uCxmlpssBxz56BPx/2xAl5w6GUUrlVpgw88YTTUXhWRuOQTp2kbWSPHjBjBhQq5HRkZxjr4hwWY8xooDMQDvwF3GCtjTtnm65AV4CSJUs2mDZtmluDvZB580ozZEg1UlJCaNz4CP36/UNEhOezcFxcHIUceCZTUw1JSSEePUanjs1bAvn4evbsSWpqKh9++KHToXhMID9/4NnjO348jP37w6lRI8Yjj58dbzx3p06F8N571Vm8uBQhIZZnn/2PO+/c59F9ZmjZsuUaa+01Lm1srXX5BoQCTYF+QNjFto2KirLesnSptcWLWwvW1q1r7e7dnt/n4sWLPb+TC+jf39qhQz27D6eOzVsC+fiaN29u69at63QYHhXIz5+1nj2+FSvkNcQpnn7uDh60tnFjyQWRkdbOmePR3Z0H+MO6mE9zVO1srU211i4HygPdc/SWwINuuEEKsapWhXXroGFDaUwRiF58UcrjlVIqJ1JSpPDo1VedjsQzNm6UWTArVkDFitKxql07p6PKWm6nGuXDwTHfC6laVRJw8+awf78k5BkznI7K/cLC5I9q0CCnI1FK+ZMRI6RRUSD6+WfpybxjB1x7rfRorl3b6aguLtvka4wpZYx5wBhTyBgTaoxpCzwILPJ8eDlTvLisTNGpk3TBuusu6WASaK0Zr7gCWrRwOgqllD/p0UPWqw00n3wic5ZjYuDuu2HJEv9oSuTKma9FLjHvAY4DQ4Ce1tqZngwsty65RBZhGDBAkm6fPvDUU1LxFijKlpV3dcuWOR2JUsofTJgA//sfFCnidCTuk5oKL7wgvZlTU+Hll2WxhIgIpyNzTbZTrK21h4HmXojFbYyBV16BK6+Us+Bx42Ru2zffeL95tqccPQqffw7NmjkdiVLK1xUvHliJ9+RJeOQRGVrMl0/6M/vb9Cm/aS+ZG/fdJ2sBlyoFCxbIgsnbtzsdlXtUqCBn+IE+t1kplTebNkkzoio+VaWTe/v2SW3PjBlyMjVvnv8lXgjw5AtS3bdqFdSsCf/8c6YaLhAkJECtWsG3ypNSynXPPw9btzodhXusWyev4WvWSO3LihXQqpXTUeVOwCdfkNaMv/4qK1ocPiyrIk2Z4nRUeRceLuO+/jLGoZTyvh9/lNkg/m72bOnnv2ePfFy1CqpXdzqq3AuK5Asy3jF7thRfJSXBgw9C//7+XwldooQcx7FjTkeilPIl8fHQtGlgXBkbORI6dJCx3ocekmHEEiWcjipvgib5ggzMjx4NH3wgRVkZvT+TkrL/WV9WsWLwLA2mlHJNRIQUm/rzlbGUFHj2WZkmlZYGb74JkyZBgQJOR5Z3QZV8QZJuz54yWF+wIHz5pawVfPSo05HlXseO8od54oTTkSilfMHJk/DZZ7IMq7+KiZGz3VGjZArppEnwxhuBsxpT0CXfDB06yHhpuXLy8brr/Hu5voED4bffnI5C
KeULjh2DI0ecjiL3du2SS+Zz5sjl5YUL4eGHnY7KvYI2+QLUry8LK9evD1u2SHuyJUucjip3Ro6ULi9KqeCWmAiXXgovveR0JLnz++/Sn3/9eqhWTdoGN23qdFTuF9TJF+Cyy+CXX+RM+PhxqYj+4guno8qdiRNh8GCno1BKOWn+fBkn9UfffSdzeA8elFkpK1YEzvzkcwV98gVZYHn6dFktKDkZHn9cOmSlpTkdWc60bi2xK6WCV4cOMHas01HkjLVy4nDPPdK/4IknYO5cKFbM6cg8R5NvutBQWYRhzBj5/N134f775Q/BX5QrJ5Xb33zjdCRKKSeMHAmzZsnqZ/7i1Cl48skzl8kHDZLFEi65xNm4PC3b3s7B5qmnpHPKvffCt9/KwP+sWVC6tNORuSYtTZbVUkoFn7Zt/WsazvHjcra7aJHEPWmSrEwUDPTM9wJuukkqhy+/XAqyGjWCDRucjso1FSrISh8HDzodiVLKm+bNk5OEyy93OhLXbN0qRa6LFkncS5cGT+IFTb5ZqllT2pc1agQ7d8qiDPPmOR2Va+Li4MYbpepRKRUc5s+XubH+4Ndfz0zvrF1bTnIaNnQ6Ku/S5HsRpUvLqkj33QexsbIyyJgxTkeVvUKFZO1Of7r8pJTKvZMnpWalYkWnI8ne11/LYghHjkC7drB8uX/E7W6afLMRHg6TJ8Orr8ryfU8/Db16+f5SfiEh0LmzfzcOUUplLzYW6tXz/Std1sKECZfz8MNSZPXMM/DDD1C4sNOROUOTrwtCQmTxgvHjpYpw+HB4/fVaxMU5HdnFPfusrOiklApckZGy1J4vX+lKSoJHH4Xx4ysTEgIjRsBHH0m//WClyTcHOnWCn3+WuWe//VaCZs1keStfdc01su7lP/84HYlSyhM2b4aXX/btxROOHJEeBF99BQUKpDJzJjz3nNNROU+Tbw41by7tzsqXj2ftWinI+vNPp6PK2rZtsG+f01EopTyhRAmZneGr/v1XXiOXL4fy5eHDD/+ifXuno/INmnxzISoKPvroT264QRJbs2Ywc6bTUV3YI49IcYOueKRUYNmyRVZja9XK6UgubNEimUq0bRtcfbXMHrnySh8fq/MiTb65VKRICvPny6Xo+Hi4804YNkyKCnzNjBnw/PNOR6GUcqe1a313IZjPP5eGH9HRcPvt0j+/XDmno/ItQTzcnXf588siDFWrQr9+kuA2b4YPP/St9m4dOshNKRUYTp2SzlC+Ji1N+uIPGiRf9+kD770nLXvV2fTMN4+MkWlIU6dKMh47VuYD+9Jl3tBQKXp46CH/WyxCKXW+Bx/0vbPe+HjpiTBokLzmjB0L77+viTcrmnzd5L775J+hZEmpiL7+eti+3emozihVCrp0kTcLSin/NmGC1Jr4igMHoEULWRKwcGGYMwe6dnU6Kt+mydeNrrtOigpq1ICNG6XKb8UKp6MSxkhhxrRp/rVSk1LqjJQU6NZNPveVM8r16+W17vffpa/AihXQpo3TUfk+Tb5uVrmyLMrQpg0cPiwLQk+d6nRUZ2zYoIsuKOXP2rSBggWdjkLMnQtNmsjqb5lPPlT2NPl6QJEiMHu2vENNSoIHHpAOWb5QCf3OO1J1GBvrdCRKqZw4eVLe2N9zj28MH40eLfUtsbGy9vmiRTK8pVyjyddDwsJkEYahQ+Uf5bXXpNdyUpLTkUkCnjTJ6SiUUjmxbZv0QnZaair07Cm9mdPSZKbH119LH3zlOp1q5EHGQO/eUKWKVBpPnCgL3U+fDpde6lxcb7wR3D1VlfI3iYlQq5ZUDzspLk4qrX/8UU4wPv0UOnZ0NiZ/pWe+XnD77bBsmVzu/eUXGRvZvNm5ePLlk5aYzzzjXAxKKdcNGSILujhpzx6psP7xRyheHBYs0MSbF5p8vSSjvVq9etIWrnFjWLrUuXiqV5epR0op39e375kqZyesWSOL3a9dK02FVq6EG25wLp5AoMnXi8qXlzPg226DY8ekanHC
BGdiiYiAmjXlspEvFIIppS7suedg507nVi6aOVMS7f798nHFCknAKm80+XpZoULw/ffQqxckJ0sRVr9+znSeCg2Vy9/x8d7ft1LKNbfcIm/cvc1aKRi98055jchYUtXJepVAosnXAaGhsgjD6NHy+YABUsTg7eYXoaEweLDs99Qp7+5bKXVxKSnSFKdtW7jkEu/uOzkZuneX3szWylTJL77wfhyBTJOvg7p3l/nAkZHyT9aypTMNMHr2hF9/9f5+lVJZO3xYxla97cQJmb87dqz0q586VfrX+8Lc4kCiyddhbdvKxPnLL5eCrEaN4O+/vRvDxImS+JVSviEuDooVkytk3kx627dLX/qff5Y+9UuWSN965X6afH1ArVpnEu/OnfLHP2+e9/YfEiJn4L17e2+fSqmsTZsm8/G9acUKeQ3auFFaRK5aJdMilWdkm3yNMfmNMZ8ZY3YaY2KNMX8ZY272RnDBpHRpWLwY7r0XYmLkss/HH3tv/02byviOUsp5jz8u46zeMnWqXP06fFhmYfz2m/SpV57jyplvPmA30BwoArwGTDPGVPJcWMEpPBymTJHFqFNTZUy4d2/53NOKFJGx54EDdeqRUk4aPLgaa9dKBylPs1YKPh94QFrfdusmV8GKFPH8voNdtsnXWnvSWvumtXaHtTbNWvsjsB1o4Pnwgk9IiPwzfPGF/PN98AHcdZeMAXlaRIQUWKSkaGWFUk65++49XlkZKCnpzFRHY2Ra0Zgx3kn6KhdjvsaY0kAU4OWyoODSuTPMny9FF7NmSVu3PXs8u8/QUHj+edi/P1zn/irlZUlJkgArVTrp8Sk9R4/CTTdJsWVEhPQe6N1bK5q9KUft9Y0xYcBXwARr7b8X+H5XoCtAyZIlWbJkiTti9ElxcXFeOb4RI8Lp27c2a9dGUL9+EgMHrqdqVc+eBk+eXJm4uDXUqBGY6w5667lzQnR0NKmpqQF7fBC4z19MTD62bi1DtWqePb49e+Q1Zc+eCEqUSGLAgPUUKRKHN36lgfrc5Yq11qUbcpY8BfgJCMtu+6ioKBvIFi9e7LV9HT5sbbNm1oK1ERHWzpzp2f1lHFtSkmf34xRvPnfe1rx5c1u3bl2nw/CoQHz+Dhyw9uBB+dyTx7dkibXFi8trSd261u7e7bFdXVAgPneZAX9YF3OqS5edjTEG+AwoDdxtrU320HsBdQElSsi8u0cflTZvd9whY8GeLIyaMQOeftpzj6+UOmPhQvjkE8/uY+JEqWQ+dgzat4fly51pW6mEq5edxwBXAa2ttV5ugqhACqEmTICoKHjtNRmf2bwZPvzQM2vz3nKLjAkppTwrJUXW+/aUtDSZM5wxdalnT1miMDTUc/tU2XNlnu/lQDegHnDAGBOXfnvY49GpsxgjlYmTJ0sy/vhjmQ984oT795VR8PHww97vOa1UMGnbFtav98xjJyRIYu/fX2ZSjBolV8008Tov23Mma+1OQGvgfMgDD0g7yttvl4roJk1kgetKldy7n4gIudStzdSV8pwpU2Royd0OHZLXiJUrz/SPb9fO/ftRuaPtJf1U48bS/u2qq6QXdKNGnmnC3q6d9Hfdts39j61UMNu9G555RhKvu6f4bNx45jWhYkVZOEUTr2/R5OvHKleWNnBt2si73JYt5d2tu+3Y4cxqS0oFsuLFZa1cdyfen3+WN+c7dsC118qb9Nq13bsPlXeafP1c0aLSDq5rV0hMhPvvd3+LyCeekHfR27e77zGVCmbLl8OuXdC6tXsfd9w4uPlm6Q9/zz1y1apMGffuQ7mHJt8AEBYmxVdDh8q76Fdfhcceg1On3LePtWt14QWl3GXXLti/332Pl5oq/5/dusnnL78siyVERLhvH8q9PDBJRTnBGJl+dMUVUqE8YYJcdvruO7j00rw//tVXw7ffyj+2VkoqlXvbt7t3atHJk/I/P3OmTDscO1ZWRVK+Tc98A8wdd8CyZVCuHCxdKmM///3nvsdv1UrmFyul
cu7UKVmcPjraPY+3bx/ccIMk3qJFZfaDJl7/oMk3AF19tRRZ1K0rife66yQR55UxUtAVFZX3x1Iq2KSlyZnp6tWSKPNq7Vpo2BD+/BOqVIEVK6ToUvkHTb4Bqnx5Kepo317aybVpI+3l8qp0aZg7VybrK6VcN22arBrmjurmH3+Epk1h716Z579yJVSvnvfHVd6jyTeAFSokPZp79oTkZOjUSTpkpaXl7XGrV5d/fKWU6+69V4oh88JaGDFCmmdkjPUuXOiZJh3KszT5BrjQUGknN2qUfD5ggBR75KVlZKVK0tzj0089u7iDUoHAWmmmsX173pJkSgr83//Jm+m0NHjrLfjyS2k1q/yPJt8g8fTTcqkqMlKmILRqJY05ciskBLZskVWWlFJZM0aaaVSsmPvHiImB226TN9GXXAJffQWvv+7+Bh3KezT5BpF27aTNXMWKMkbUqJG0psyNfPngvfcgLg4OHHBvnEoFiqNHZdpf69a575G+c6eM686dK2fOixZ5dhUk5R2afINM7dpSCd2wocwDvv56aUeXWxMm5O3nlQpk0dFy1ppbv/8ub5I3bJBai5UrJREr/6fJNwiVKQOLF0v7uZgYaUc3dmzuHuvFF2Xlo5Mn3RujUv5u/XqZb/9//5e7n//uO2jeXPqqt2olfdyrVHFvjMo5mnyDVESEjP327Stdq556SqZBpKbm/LEOHZK5xCkp7o9TKX/1ySfw1185/zlrYdAgeXOckCC91efOhWLF3B+jco62lwxiISGyCEPVqrIww7BhsHUrdOuWs/dkpUrJ5bB8+tekFCCXm0eOzPnPnToFQ4ZU46ef5OtBg+CFF7SwKhDpma/iscekLV2xYtKmrkeP+uzdm7PHKFhQqi+nT/dMjEr5i02bpLlNTqfhHT8uRZE//VSW8HC57Pzii5p4A5UmXwVIW7oVK2RM6b//ImnUKOeXzDp1grZtPROfUv7AWqhWTRpf5CRpbt0qfdgXL4bixZNYuhTuustzcSrnafJVp1WrJpeP69SJZu9eaNYMfvjB9Z+vUkWmHr34ojbfUMHp6adh3rycNb5YvlwqmjdtktkIo0f/ybXXei5G5Rs0+aqzlCgB77+/jkcekQrm22+H4cNdT6bFi0O9epSae5EAAB2TSURBVJ6NUSlf9dprssqQq776Cm68UeYD33yzJOLSpZM8F6DyGZp81XkuucQycSK8/bYk3V69pD2eK9XMYWHSAGDBAumApVQw2LMHuneHsmUhPDz77a2FN9+ERx6RIqtnn4VZs6BwYY+HqnyEJl91QcbIu/ivv5ZLaGPGSBHJiROu/fyePXD4sGdjVMpXXHqptJB0ZZw3MVGS7ltvyYyDkSPhww91tkCw0eSrLurBB6WdXcmSMpbVpIl0xsrOY4/JONbvv3s8RKUc9fXXsHs33HRT9tsePiytJr/+WlYdmzUr9004lH9z7L1WTEwMhw4dIjk52akQ8qRIkSL8888/TofhEeceW6lSYSxbVoo77yzM339LUp01Sz5ezKFD0L+/TD8KDfVw0Eo5JDFRhluy8++/cOutsG2brLf9449Qt67n41O+yZHkGxMTw8GDB7nssssIDw/H+OFEttjYWCIjI50OwyMyH5u1loSEBPbu3cuCBdCpU2EWLIAWLWDiRFmjNCtlysi84bg4GeMK0F+XClKJibJQyeOPZ7/tokVw993SfKNBA5lFULas52NUvsuRy86HDh3isssuIyIiwi8TbzAxxhAREcFll11GfPwhfvoJnnxSXnjuuw/efTf7SuhBg2DaNO/Eq5S37NwpZ6/Z+ewzmf8eHQ133AFLl2riVQ4l3+TkZMJdKQlUPiM8PJzk5GTCwmQRhiFDpLjklVfknf+pU1n/7JtvSn/ai22jlD/ZvFnasn7wQdbbpKXByy9Dly4yU+CFF6RrVcGC3otT+S7HCq70jNe/ZH6+jJFFGKZPlwUaxo+XYpNjxy78s6Ghcum5fn1d/UgFhldflWX+
shIfL1eGBg2Sv/9x42DwYKluVgq02lnlwR13wC+/yCW0pUulPd5//11420KFpIFAwYLa/Ur5r5QUWYZz2jSoU+fC2xw4IDUR330HRYrIikRPPunVMJUf0OSr8qRBA1i9Wqo2N2+WpQWXLbvwtsWKyRSLt9/2boxKuctPP0Hv3lnP512//swUu8qVZQ3e1q29G6PyD5p8VZ6VLy8J99Zb5dLzjTfCl19eeNt27WTtYKX8TUoKdOgAo0df+Ptz5sg8+F275CrQypVQo4Z3Y1T+Q5NvDrVo0YJnn33W8cfIzvHjxyldujRbt27Ndtt77rmHYcOG5Wl/kZEZyxFCcjJ07ChLDJ57ibl4cVn/97HHpFpUKX+QkgLXXgtHjsAll5z//VGjpANcbCw88IBMLSpVyvtxKv+hyTdADRw4kFtuuYUqVapku+0bb7xB//79OeFq78gshIbKIgwffSSFJe+8I32eExPP3s4YSb7lyuVpd0p5hbXS+nHePFl4JLPUVOjZU3ozp6VJS9avvoICBZyJVfkPTb4B5FT6XJ74+Hg+/fRTnnjiCZd+rnbt2lxxxRVMmjTJLXE884zMf4yMhClToFUr6XaV2Q03yJnvwIFu2aVSHjNggFT0n3smGxsrq36NGCEdrjIWI9GKZuUK/TPJhbS0NN566y1KlChBqVKl6NOnD2lpacCFLyl37tyZ9u3bn3VfSkoKPXr0oFixYhQrVowXXnjh9GOAdJYaPHgwVapUITw8nNq1a5+XHFu0aEH37t3p06cPJUuWpEmTJgD89NNPhISEnP4aYPDgwRhjzru9/vrrAHTo0IHJkye77Xd0883S/adiRVixQopQNm48e5uSJWWupFK+7KmnZKw3sz17ZL3r2bNlKGXBAnj0UWfiU/7JJ5KvMc7ccuurr74iNDSU3377jY8++ojhw4czderUHD9GWloaK1asYOzYsYwbN47hw4ef/n6/fv347LPPGDVqFBs3bqRv375069aN2bNnn/U4kyZNwlrLsmXLmDhxIgDLli2jQYMGZ83N7d69O/v37z99e/755ylTpgwdO3YEoGHDhqxevZqEhITc/lrOU7s2rFolY2U7dkgRys8/n/l+kSLSnnL6dFi71m27Vcottm2Dhx+WFYuKFz9z/5o10LAhrFsHUVFSWJWTNXyVAgcXVvBnNWrUoF+/fkRGRhIVFcUnn3zCwoULefDBB11+jLJlyzJy5EiMMVSvXp3NmzczbNgwevfuzcmTJxk2bBjz58+nWbNmAFSuXJnVq1czatQobr311tOPU7lyZYYOHXrWY+/cuZOy5/Svi4yMPN2vedCgQUyePJklS5Zw5ZVXAlCuXDmSk5PZt28fpdxYKVKmDCxZIgVY330nZ8SjR0PXrme2CQ3Vub/K91SsKAWEmd+oz5ghCTk+/sxc3syJWSlX+cSZr7XO3HKrzjmz68uVK8ehcwc1s3HdddeddWbauHFj9u7dS0xMDBs3biQxMZF27dpRqFCh07cxY8acV73coEGD8x47ISGBAllUfLz77ruMHDmSxYsXU61atdP3Z7T7dOeZb4aICGlK8PLLUqDSrRv06SOfg4yb1akDn3wiVaVKOclamcu7c6ec4WbcN3Qo3HWXJN5OnaQASxOvyi2XznyNMc8CnYHawGRrbWcPxuTzws5ZP8wYc3q8NiQkBHtOZs/psokZj/XDDz9QsWLFi+674AUaxZYoUYLjx4+fd/+AAQP4+OOPWbp06ekz3gzH0ntDlixZMkexuiokRBZhqFpVku/QobB1K0yadKbr1Y4d0n6ySBGPhKCUS4yBNm3gssvk6+RkqWYeN06+HjhQ3khqh1yVF66e+e4D+gOfezCWgFCyZEn2799/1n3r1q07b7tVq1adlaRXrlxJuXLlKFy4MDVq1CB//vzs3LmTK6+88qzb5Zdfnm0M9evXZ+M51U3vvPMOY8eOPetSc2YbNmygXLlylC5d2tVDzZXHH4f586FoUbmEd8MN
sG+fTOUYMEDOfBcu9GgISmVpxgypQbj5ZpkuFB0Nt9wiibdAAbmC07evJl6Vdy4lX2vtdGvtDOCoh+Pxe61atWLOnDnMmjWLTZs20bt3b3bv3n3edvv27aNnz55s2rSJb7/9lvfff59evXoBMj7bp08f+vTpw+eff86WLVtYu3YtH3/8MeMy3n5fRNu2bfnnn384elSergEDBjBixAimTJlCwYIFOXDgAAcOHCAx0wTcZcuW0a5dOzf9Fi6uZUspUqlSBf78Uy7tZRRc7d0rPaCVcsIVV0ClSvL59u3SsWrBAplmtGTJxdevVion3FpwZYzpCnQFOQNcsmTJBbcrUqQIsbGx7ty116SmpnLq1ClSU1NPH0NycjIpKSnExsZy77338scff/DYY48B0KVLF9q3b8/Ro0dPb5+amsp9991HQkICjRo1whjDo48+SpcuXU5v8+KLL1KkSBEGDx5M9+7diYyMpE6dOvTo0eOsxzl16tR5v8tKlSrRoEEDxo8fz5NPPsngwYOJiYk5a+oRwKxZs2jRogWJiYl8//33TJ8+ndjY2LOOLbPExMQsn9PcGDo0jNdeq8n69UVp3DiV11/fSOPGR2neHD79tBCFCydTqlSS2/aXIS4uzq3H4Uuio6NJTU0N2OMDzzx/0dFhzJxZjo4dd2IMjBpVmH79ahEdfQmVKp3k3XfXk5CQiDd+rYH89xnIx5Zj1lqXb8il5/GubBsVFWWzsnHjxiy/5y9iYmKcDuGi5syZY6OiomxKSkq223700Ue2TZs2p7/O6tg88bwlJlr7yCNSAhcSYu3w4dampVk7cqS1c+a4fXfWWmsXL17smQf2Ac2bN7d169Z1OgyP8sTzd+yYtRMmyOdTplibP7/8Td50k7XR0W7f3UUF8t9nIB+btdYCf1gX86lPVDsr92vXrh3PPPMMe/bsyXbbsLAwPvzwQy9Edb78+aUz0FtvSXu+jFZ93bvLIgwLF56pilbK3ayVQkBrpUlG//7SmzkpSZprzJ6tBYDKM3SebwB77rnnXNqua+ZJtw4wRhZhqFoVOneWecDbtsHkyVINXb36mcpTpdzJWmmDaoz87U2cKJ8PG3b+HF+l3MnVqUb50rcNBUKNMQWAFGutzspUbvPgg9LY4I47ZAHyZs2kR3Tp0lLs0qKF0xGqQPLVV3DVVfJ3d8cd8MsvMid98uTz20kq5W6uXnbuByQALwOPpH/ez1NBqeDVpIm0pKxeHTZskJ7Qc+ZIY3vtgqXcqVAhWfDjuusk8ZYrJ+tSa+JV3uDqVKM3rbXmnNubHo5NBakrrpDFGG68EQ4ehPvuk7VSDx+G3393Ojrl7/73P7m8XLSotIrcsgXq14fVq+Hqq52OTgULLbhSPqloUTnj7dJF1gO+91549VVYvNjpyJS/K1BAFkdo0waOHYPbbpMzX60rUN6kyVf5rLAw6Sw0eLAUvnz6Kfz7rzThcKGIW6mzHDgAr7wCEybAyJHSNrJXL/j+e7kErZQ3afJVPs0YeOEFWT0mPBy++AIeewz++MPpyJS/yZdPrpwMHCgraY0eLVXNoaFOR6aCkSZf5RfuvFMuDZYpI2N0L70kL5wnTjgdmfJ1SUmypGW7dtLWNDJS5u927+50ZCqYafJVfuOaa6Qopk4d2LwZ+vWDn392Oirl6/77T6aurVkDl18Ov/0Gbds6HZUKdpp8lV+pUEHGfG+5BRIS4KGHpBraA8sQqwBw880ylejwYZm2tmoV1KrldFRKafLNsQ4dOlCsWDEeffRRp0MJWpGRMHMmPPecFM188420BdR5wCqDtTB2rCxfefKkVMsvXiwNW5TyBZp8c6hXr15MnDgxxz+3e/duWrRoQY0aNahbty7Tp0/3QHTBI18+GDECPvwQQkKkiKZ6dXmhVcEtNVXm6z71lPQLf+UVmDJFCvaU8hWafHOoZcuWREZG5vjn8uXLx/Dhw9m4cSM///wzPXr0ID4+3gMRBpdnn4UffoCCBWUc
+Kab5BKjCk4nTsBdd8n60PnySXX8gAHyBk0pX6J/kl5StmxZ6tWrB0CpUqUoVqwYR44ccTiqwHDLLVJEU6GCfKxcGf75x+molLft3Su9mmfNgmLFpBivc2eno1LqwjT5OuCPP/4gOTmZChUqOB1KwKhTR4pprrlGLj03bgzz5jkdlfKWFSvkud+/H6pUkSlFuhCH8mWafL3s6NGjdOzYkc8++wyj65W5VdmysHQp3H23XH68+Wb45BOno1Ke9sMPkmgPHJCVsFauhKgop6NS6uI0+brR4MGDMcacd3v99dcBSEpK4s4776Rv375cf/31DkcbmCIiYNo0ePFFqXjt2hWefloKb1RgsVaq3Dt0gFOn4JFH5FJziRJOR6ZU9jT55lDr1q259957mT9/PuXLl2fFihWnv9e9e3f2799/+vb8889TpkwZOnbsiLWWzp0706pVK52m5GEhITBokPSCDgmBMWPkbFgroQNHSooU2732mnz9zjuyUlH+/M7GpZSr8jkdgL9ZsGABALGxsedVPUdGRp6+b9CgQUyePJklS5Zw5ZVXsnz5cqZOnUqdOnWYMWMGAF9++SW1a9f27gEEkSeekOKru+6CGTNkTHDhQqejUnl18mQoLVrAr79Ksv3iC3jwQaejUipnNPl6wLvvvstHH33E4sWLiUoffGratClpeu3T61q1kjHAVq1kRaRGjeCNNwpqMY6f2rkTnn22Pjt2QOHCsuykjuAof+Qzl53ffFNuIMUSmzdLL9YGDeS+55+HoUPl83LlYN8+WLLkTEVj166y/BxIB6TYWCnEuO02ue+hh+Drr+Xz3NY5ZR7HLVy48HljuwADBgxg9OjRLF269HTiVc6qXl0WUG/SRJYifPrpq5k92+moVE6tXg316sGOHYW46ir46y9NvMqPWWs9couKirJZ2bhxY5bf82W7du2yzZs3t1dddZWtVauW/e677876/ttvv20rVKhgt2zZ4lCE7hETE3PB+/31ecuQkGDtEzftsuXZZY2xdtgwpyNyv+bNm9u6des6HYbbjR9vbf781pZnl70xaoM9ftzpiDxn8eLFTofgMYF8bNZaC/xhXcyRPnPm6w8yd6maOXPmWV2qBgwYwIgRI5gyZQoFCxbkwIEDHDhwgMTERIejVhkKFIBP5lagdecUrIXevaVoJyXF6chUVlJS5Hnq3FmWBrylawVeHnWEokWdjkypvNHkmwOZu1SVLFnydJcqay2DBw/m6NGjNGnShLJly56+/frrrw5HrTIz06bSp8IXTJwIYWEwapSserN/v9ORqXMdOQJNm8IHH0iryFGjYGyrqZT9RavmlP/Tgqtc+vPPP093qTLGcEJXdfcPY8ZwWXQ0Nde+TZUqMkd0zRqoXx+mToXmzZ0OUAH8+adMD8sorJo9WxIxLeT54+23nQ5RqTzRM99cOHr0KN26ddMuVX7u+uvh77+hZUs4eFA+Dh6sSxM6yVpZreraayXxXnutPEdNmzodmVLupck3hzK6VPXu3Vu7VAWA0qVlzddeveSF/6WXpC3loUNORxZ8jh6F1q2hZ0/pSPbEE/DLL1C+vNORKeV+mnxzwGbqUvWgzuoPGPnywbBhMHOmXOKcNw9q1UKnI3nRsmUyjWjRIlke8rvvpENZgQJOR6aUZ2jyzYFff/2VqVOnMmPGDJo0aUK9evVYv36902EpN+nQAdavl+b8hw9D+/Zy9qXLLntOQoLM4b/hBpmD3bgxbNggXcmUCmRacJUDmbtUXai9pPID337L37/+SpMsvl2xojRvef99eOUV+PxzufT5xRc67uhuv/0GDz8sY7vGSBIeOFCq0LOUzfOnlL/QM18VXEqUILlIkYtuEhIiY79//gk1a8KWLXI2/NRTEB3tpTgDWHw89OkjHcd27IBq1WQt5vffzybxgkvPn1L+QJOvCi7jx1Nm7lyXNq1bF/74A/r2lXHhsWPhiitkPFIronPOWhlXr1pVWsUaI29y1q6VqmaX
5OD5U8qXafJVwSWHL94FCsil0LVroU4dOH4c7rlHeor/84/nwgw0W7bIHOo77pC+7FdeKQtevPdeDouqNPmqAKHJVykX1KwpjfzHjJFq3F9+gdq1pT3lkSNOR+e7TpyAF1+EGjWkojk8HEaOlDcuDRs6HZ1SztHkq5SLQkJk3HfbNujWTeaijhoFVarIZdSEBKcj9B0JCfI7ueIKGctNToZOnWD7dvi//5PL+EoFM8eSr9VBM7+iz9cZpUrBxx/DunVSiBUTIwVE5cvD8OHBnYSTk+Gzz+QNSZ8+cOyYFFb99huMHy9NTZRSDiXfsLAwEoL5FcoPJSQkEJZtKWpwqV0bli6VZhxVq0qi6dULypSRph0nTzodoffExckbj8sugy5dZKGKqChZ7H7ZMpm/q5Q6w5HkW6pUKfbu3Ut8fLyeUfk4ay3x8fHs3buXUqVKOR1O3v30E/977z23PZwxcMstsGkTzJolY5sxMTJntUwZ6NFDptMEqoMHoV8/OaPt1Uuak1SoAJMny7huu3byO3IbNz9/SjnFkZGXwoULA7Bv3z6Sk5OdCCHPEhMTKRCgve/OPbawsDBKly59+nnzaxERpHngeTMGbrtNumLNng0DBkg178iR8OGH0KoVvPyyfAzx80qLtDRYsECObe5cSE2V+6+5Bl57TX4HHjtGDz1/SnmbY2UPhQsX9usX8yVLllC/fn2nw/CIQD42Ro+m3ObNMlfIA4yR5NO+PaxeLYl38mRYuFBuJUtCx46yOHytWh4JwWO2bIEpUyTpHj4s9xkDt98OL7wgY7se5+HnTylv0ZpDFVymTaOUl9pUNWwIX34p1b7jxkkh0q5dUgU8dKjMde3USS5b16/v5suzbrJxI3z7rbyB+PffM/eXKwfdu8Njj8k4r9d48flTypNcujhkjClujPneGHPSGLPTGPOQpwNTKlCUKQOvvy7TbJYtk2lKhQrJmeRrr0GDBlC2rCzi8M030oTCKQcOSKJ94gl5c1CzJrzxhiTeAgWkF/O8ebB7t4z1ejXxKhVAXD3zHQWcAkoD9YDZxph11tq/PRaZUgEmJEQWZ2jaVBaMnzcPfvxRziwPHpRFHD7/XLYtVUou4zZvLp21qleXJO7Os+OjR2W61F9/SQevxYth796zt4mMlLHshx+GG2+E/Pndt3+lglm2ydcYUxC4G6hlrY0DlhtjZgGPAi97OD6lAlL+/LKEYYcO0jP6f/+TaukFCyQRHjoE338vtwwFC0LlyjJWXLSozCsuXVrGkcPDpXHFiRNw8mQ+li+H2NgztxMnpDnInj2wcyds3Sr3nys8XN4ctGgBbdrI5XBtiKGU+5nspvoYY+oDv1lrwzPd1wdobq29Laufi4iIsA0DuH9cdHQ0RYsWdToMjwjkY2PtWlJSUsh3zTVOR5Ila2Xln5gYSZpxcZCUBCkprvz02vSP9bLdMiQEIiIkqRcuLJfCIyN9c+z5ND94/vIqkP//AvnYAJYuXbrGWuvSH6crybcZ8I21tkym+54EHrbWtjhn265A1/QvawEbchC3vykBBGpX30A+NtDj83d6fP4rkI8NoJq11qWF3l25oBQHnDsnqDBw3kUra+04YByAMeYPV98B+KNAPr5APjbQ4/N3enz+K5CPDeT4XN3WlWrnzUA+Y0zVTPfVBbTYSimllMqFbJOvtfYkMB142xhT0BjTBLgd+NLTwSmllFKByNUmcE8D4cAhYDLQ3YVpRuPyEpgfCOTjC+RjAz0+f6fH578C+dggB8eXbcGVUkoppdzLz1u8K6WUUv5Hk69SSinlZV5JvsaYqsaYRGPMJG/sz1uMMZOMMfuNMTHGmM3GmC5Ox+Quxpj8xpjP0nt5xxpj/jLG3Ox0XO5kjHnWGPOHMSbJGDPe6XjyKpB7sAfac3WuQP9/C+TXysxykuu81ThuFPC7l/blTe8CT1hrk4wx1YElxpi/rLVrnA7MDfIBu4HmwC7gFmCaMaa2tXaHk4G50T6g
P9AWKSj0d4Hcgz3QnqtzBfr/WyC/Vmbmcq7z+JmvMeYBIBpY6Ol9eZu19m9rbVLGl+m3Kg6G5DbW2pPW2jettTustWnW2h+B7UADp2NzF2vtdGvtDOCo07HkVaYe7K9Za+OstcuBjB7sfi+QnqsLCfT/t0B+rcyQ01zn0eRrjCkMvA0878n9OMkYM9oYEw/8C+wHfnI4JI8wxpQGotDmKr4qCki11m7OdN86oKZD8ag8CMT/t0B+rcxNrvP0me87wGfW2t0e3o9jrLVPA5FAM6QZSdLFf8L/GGPCgK+ACdbaf7PbXjmiEHDinPtOIH+byo8E6v9bgL9W5jjX5Tr5GmOWGGNsFrflxph6QGvgg9zuw0nZHV/mba21qemX+coD3Z2JOGdcPT5jTAjSzewU8KxjAedQTp6/AOFyD3blu/z1/81V/vhamZ3c5rpcF1ydu6LRBQLqCVQCdhlZo6wQEGqMqWGtvTq3+/WW7I4vC/nwk3EMV47PyBP3GVLAc4u1NtnTcblLLp8/f3a6B7u19r/0+7QHux/x5/+3XPCb10oXtCAXuc6Tl53HIb/ceum3j4HZSLWi3zPGlDLGPGCMKWSMCTXGtAUeBBY5HZsbjQGuAm6z1iY4HYy7GWPyGWMKAKHIP0sBY4xfLh0f6D3YA+m5uoiA/H8LgtfKXOU6jyVfa228tfZAxg25LJZorT3sqX16mUUum+wBjgNDgJ7W2pmORuUmxpjLgW7IH9MBY0xc+u1hh0Nzp35AAvAy8Ej65/0cjShvctOD3V8E2nN1lgD/fwvo18rc5jrt7ayUUkp5mbaXVEoppbxMk69SSinlZZp8lVJKKS/T5KuUUkp5mSZfpZRSyss0+SqllFJepslXKaWU8jJNvkoppZSXafJVSimlvEyTr1IBwBjzYhYrOL3tdGxKqfNpe0mlAoAxJhIomOmuPsDDQDNr7RZnolJKZUWTr1IBxhjzEvAc0Mpau8npeJRS5wu0JbmUCmrGmL7IIuwtrbWbnY5HKXVhmnyVChDGmFeBp4DmeqlZKd+myVepAGCMeQ14Emhhrd3qdDxKqYvT5KuUn0s/4+0BdABOGmPKpH8r2lqb6FxkSqmsaMGVUn7MGGOAaKDwBb7d2lq70MshKaVcoMlXKaWU8jJtsqGUUkp5mSZfpZRSyss0+SqllFJepslXKaWU8jJNvkoppZSXafJVSimlvEyTr1JKKeVlmnyVUkopL9Pkq5RSSnnZ/wPQss9uW/1k+QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 576x252 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(8, 3.5))\n",
    "z = np.linspace(-4, 4, 200)\n",
    "plt.plot(z, huber_fn(0, z), \"b-\", linewidth=2, label=\"huber($z$)\")\n",
    "plt.plot(z, z**2 / 2, \"b:\", linewidth=1, label=r\"$\\frac{1}{2}z^2$\")\n",
    "plt.plot([-1, -1], [0, huber_fn(0., -1.)], \"r--\")\n",
    "plt.plot([1, 1], [0, huber_fn(0., 1.)], \"r--\")\n",
    "plt.gca().axhline(y=0, color='k')\n",
    "plt.gca().axvline(x=0, color='k')\n",
    "plt.axis([-4, 4, 0, 4])\n",
    "plt.grid(True)\n",
    "plt.xlabel(\"$z$\")\n",
    "plt.legend(fontsize=14)\n",
    "plt.title(\"Huber loss\", fontsize=14)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_shape = X_train.shape[1:]\n",
    "\n",
    "model = keras.models.Sequential([\n",
    "    keras.layers.Dense(30, activation=\"selu\", kernel_initializer=\"lecun_normal\",\n",
    "                       input_shape=input_shape),\n",
    "    keras.layers.Dense(1),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(loss=huber_fn, optimizer=\"nadam\", metrics=[\"mae\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples, validate on 3870 samples\n",
      "Epoch 1/2\n",
      "11610/11610 [==============================] - 1s 73us/sample - loss: 0.6247 - mae: 0.9969 - val_loss: 0.2884 - val_mae: 0.5913\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 32us/sample - loss: 0.2193 - mae: 0.5176 - val_loss: 0.2361 - val_mae: 0.5232\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f9fa6099f10>"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X_train_scaled, y_train, epochs=2,\n",
    "          validation_data=(X_valid_scaled, y_valid))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving/Loading Models with Custom Objects"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save(\"my_model_with_a_custom_loss.h5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.models.load_model(\"my_model_with_a_custom_loss.h5\",\n",
    "                                custom_objects={\"huber_fn\": huber_fn})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples, validate on 3870 samples\n",
      "Epoch 1/2\n",
      "11610/11610 [==============================] - 1s 55us/sample - loss: 0.2056 - mae: 0.4982 - val_loss: 0.2170 - val_mae: 0.5037\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 35us/sample - loss: 0.2006 - mae: 0.4911 - val_loss: 0.2097 - val_mae: 0.4908\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f9f800da650>"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X_train_scaled, y_train, epochs=2,\n",
    "          validation_data=(X_valid_scaled, y_valid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_huber(threshold=1.0):\n",
    "    def huber_fn(y_true, y_pred):\n",
    "        error = y_true - y_pred\n",
    "        is_small_error = tf.abs(error) < threshold\n",
    "        squared_loss = tf.square(error) / 2\n",
    "        linear_loss  = threshold * tf.abs(error) - threshold**2 / 2\n",
    "        return tf.where(is_small_error, squared_loss, linear_loss)\n",
    "    return huber_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(loss=create_huber(2.0), optimizer=\"nadam\", metrics=[\"mae\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples, validate on 3870 samples\n",
      "Epoch 1/2\n",
      "11610/11610 [==============================] - 1s 69us/sample - loss: 0.2229 - mae: 0.4893 - val_loss: 0.2525 - val_mae: 0.4973\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 32us/sample - loss: 0.2189 - mae: 0.4856 - val_loss: 0.2338 - val_mae: 0.4765\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f9f50093a50>"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X_train_scaled, y_train, epochs=2,\n",
    "          validation_data=(X_valid_scaled, y_valid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save(\"my_model_with_a_custom_loss_threshold_2.h5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.models.load_model(\"my_model_with_a_custom_loss_threshold_2.h5\",\n",
    "                                custom_objects={\"huber_fn\": create_huber(2.0)})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples, validate on 3870 samples\n",
      "Epoch 1/2\n",
      "11610/11610 [==============================] - 1s 58us/sample - loss: 0.2148 - mae: 0.4796 - val_loss: 0.2111 - val_mae: 0.4713\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 34us/sample - loss: 0.2123 - mae: 0.4775 - val_loss: 0.1970 - val_mae: 0.4534\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f9fa61cce90>"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X_train_scaled, y_train, epochs=2,\n",
    "          validation_data=(X_valid_scaled, y_valid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HuberLoss(keras.losses.Loss):\n",
    "    def __init__(self, threshold=1.0, **kwargs):\n",
    "        self.threshold = threshold\n",
    "        super().__init__(**kwargs)\n",
    "    def call(self, y_true, y_pred):\n",
    "        error = y_true - y_pred\n",
    "        is_small_error = tf.abs(error) < self.threshold\n",
    "        squared_loss = tf.square(error) / 2\n",
    "        linear_loss  = self.threshold * tf.abs(error) - self.threshold**2 / 2\n",
    "        return tf.where(is_small_error, squared_loss, linear_loss)\n",
    "    def get_config(self):\n",
    "        base_config = super().get_config()\n",
    "        return {**base_config, \"threshold\": self.threshold}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.models.Sequential([\n",
    "    keras.layers.Dense(30, activation=\"selu\", kernel_initializer=\"lecun_normal\",\n",
    "                       input_shape=input_shape),\n",
    "    keras.layers.Dense(1),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(loss=HuberLoss(2.), optimizer=\"nadam\", metrics=[\"mae\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples, validate on 3870 samples\n",
      "Epoch 1/2\n",
      "11610/11610 [==============================] - 1s 66us/sample - loss: 0.6934 - mae: 0.8818 - val_loss: 0.3288 - val_mae: 0.5451\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 32us/sample - loss: 0.2402 - mae: 0.5077 - val_loss: 0.2315 - val_mae: 0.4861\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f9fb0debd10>"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X_train_scaled, y_train, epochs=2,\n",
    "          validation_data=(X_valid_scaled, y_valid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save(\"my_model_with_a_custom_loss_class.h5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "#model = keras.models.load_model(\"my_model_with_a_custom_loss_class.h5\", # TODO: check PR #25956\n",
    "#                                custom_objects={\"HuberLoss\": HuberLoss})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples, validate on 3870 samples\n",
      "Epoch 1/2\n",
      "11610/11610 [==============================] - 0s 33us/sample - loss: 0.2263 - mae: 0.4946 - val_loss: 0.2285 - val_mae: 0.4851\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 33us/sample - loss: 0.2217 - mae: 0.4913 - val_loss: 0.1994 - val_mae: 0.4635\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f9fb0e4cf90>"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X_train_scaled, y_train, epochs=2,\n",
    "          validation_data=(X_valid_scaled, y_valid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "#model = keras.models.load_model(\"my_model_with_a_custom_loss_class.h5\",  # TODO: check PR #25956\n",
    "#                                custom_objects={\"HuberLoss\": HuberLoss})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.0"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.loss.threshold"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Other Custom Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "def my_softplus(z): # return value is just tf.nn.softplus(z)\n",
    "    return tf.math.log(tf.exp(z) + 1.0)\n",
    "\n",
    "def my_glorot_initializer(shape, dtype=tf.float32):\n",
    "    stddev = tf.sqrt(2. / (shape[0] + shape[1]))\n",
    "    return tf.random.normal(shape, stddev=stddev, dtype=dtype)\n",
    "\n",
    "def my_l1_regularizer(weights):\n",
    "    return tf.reduce_sum(tf.abs(0.01 * weights))\n",
    "\n",
    "def my_positive_weights(weights): # return value is just tf.nn.relu(weights)\n",
    "    return tf.where(weights < 0., tf.zeros_like(weights), weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "layer = keras.layers.Dense(1, activation=my_softplus,\n",
    "                           kernel_initializer=my_glorot_initializer,\n",
    "                           kernel_regularizer=my_l1_regularizer,\n",
    "                           kernel_constraint=my_positive_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.models.Sequential([\n",
    "    keras.layers.Dense(30, activation=\"selu\", kernel_initializer=\"lecun_normal\",\n",
    "                       input_shape=input_shape),\n",
    "    keras.layers.Dense(1, activation=my_softplus,\n",
    "                       kernel_regularizer=my_l1_regularizer,\n",
    "                       kernel_constraint=my_positive_weights,\n",
    "                       kernel_initializer=my_glorot_initializer),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(loss=\"mse\", optimizer=\"nadam\", metrics=[\"mae\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples, validate on 3870 samples\n",
      "Epoch 1/2\n",
      "11610/11610 [==============================] - 1s 82us/sample - loss: 1.4486 - mae: 0.8727 - val_loss: 2.4208 - val_mae: 0.5681\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 40us/sample - loss: 0.5848 - mae: 0.5260 - val_loss: 1.6040 - val_mae: 0.5122\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f9fa8b77e90>"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X_train_scaled, y_train, epochs=2,\n",
    "          validation_data=(X_valid_scaled, y_valid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save(\"my_model_with_many_custom_parts.h5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.models.load_model(\n",
    "    \"my_model_with_many_custom_parts.h5\",\n",
    "    custom_objects={\n",
    "       \"my_l1_regularizer\": my_l1_regularizer,\n",
    "       \"my_positive_weights\": my_positive_weights,\n",
    "       \"my_glorot_initializer\": my_glorot_initializer,\n",
    "       \"my_softplus\": my_softplus,\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyL1Regularizer(keras.regularizers.Regularizer):\n",
    "    def __init__(self, factor):\n",
    "        self.factor = factor\n",
    "    def __call__(self, weights):\n",
    "        return tf.reduce_sum(tf.abs(self.factor * weights))\n",
    "    def get_config(self):\n",
    "        return {\"factor\": self.factor}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.models.Sequential([\n",
    "    keras.layers.Dense(30, activation=\"selu\", kernel_initializer=\"lecun_normal\",\n",
    "                       input_shape=input_shape),\n",
    "    keras.layers.Dense(1, activation=my_softplus,\n",
    "                       kernel_regularizer=MyL1Regularizer(0.01),\n",
    "                       kernel_constraint=my_positive_weights,\n",
    "                       kernel_initializer=my_glorot_initializer),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(loss=\"mse\", optimizer=\"nadam\", metrics=[\"mae\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples, validate on 3870 samples\n",
      "Epoch 1/2\n",
      "11610/11610 [==============================] - 1s 70us/sample - loss: 1.4486 - mae: 0.8727 - val_loss: 2.4208 - val_mae: 0.5681\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 34us/sample - loss: 0.5848 - mae: 0.5260 - val_loss: 1.6040 - val_mae: 0.5122\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f9fe0415850>"
      ]
     },
     "execution_count": 97,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X_train_scaled, y_train, epochs=2,\n",
    "          validation_data=(X_valid_scaled, y_valid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save(\"my_model_with_many_custom_parts.h5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.models.load_model(\n",
    "    \"my_model_with_many_custom_parts.h5\",\n",
    "    custom_objects={\n",
    "       \"MyL1Regularizer\": MyL1Regularizer,\n",
    "       \"my_positive_weights\": my_positive_weights,\n",
    "       \"my_glorot_initializer\": my_glorot_initializer,\n",
    "       \"my_softplus\": my_softplus,\n",
    "    })"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Custom Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.models.Sequential([\n",
    "    keras.layers.Dense(30, activation=\"selu\", kernel_initializer=\"lecun_normal\",\n",
    "                       input_shape=input_shape),\n",
    "    keras.layers.Dense(1),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(loss=\"mse\", optimizer=\"nadam\", metrics=[create_huber(2.0)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples\n",
      "Epoch 1/2\n",
      "11610/11610 [==============================] - 1s 57us/sample - loss: 2.0326 - huber_fn: 0.9038\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 25us/sample - loss: 0.5862 - huber_fn: 0.2682\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f9fb0f29650>"
      ]
     },
     "execution_count": 103,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X_train_scaled, y_train, epochs=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Warning**: if you use the same function as the loss and a metric, you may be surprised to see different results. This is generally just due to floating point precision errors: even though the mathematical equations are equivalent, the operations are not run in the same order, which can lead to small differences. Moreover, when using sample weights, there's more than just precision errors:\n",
    "* the loss since the start of the epoch is the mean of all batch losses seen so far. Each batch loss is the sum of the weighted instance losses divided by the _batch size_ (not the sum of weights, so the batch loss is _not_ the weighted mean of the losses).\n",
    "* the metric since the start of the epoch is equal to the sum of weighted instance losses divided by sum of all weights seen so far. In other words, it is the weighted mean of all the instance losses. Not the same thing.\n",
    "\n",
    "If you do the math, you will find that loss = metric * mean of sample weights (plus some floating point precision error)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(loss=create_huber(2.0), optimizer=\"nadam\", metrics=[create_huber(2.0)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:sample_weight modes were coerced from\n",
      "  ...\n",
      "    to  \n",
      "  ['...']\n",
      "Train on 11610 samples\n",
      "Epoch 1/2\n",
      "11610/11610 [==============================] - 1s 66us/sample - loss: 0.1174 - huber_fn: 0.2375\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 27us/sample - loss: 0.1137 - huber_fn: 0.2306\n"
     ]
    }
   ],
   "source": [
    "sample_weight = np.random.rand(len(y_train))\n",
    "history = model.fit(X_train_scaled, y_train, epochs=2, sample_weight=sample_weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.11744698936265671, 0.11787283955197686)"
      ]
     },
     "execution_count": 106,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "history.history[\"loss\"][0], history.history[\"huber_fn\"][0] * sample_weight.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Streaming metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float32, numpy=0.8>"
      ]
     },
     "execution_count": 107,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision = keras.metrics.Precision()\n",
    "precision([0, 1, 1, 1, 0, 1, 0, 1], [1, 1, 0, 1, 0, 1, 0, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float32, numpy=0.5>"
      ]
     },
     "execution_count": 108,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision([0, 1, 0, 0, 1, 0, 1, 1], [1, 0, 1, 1, 0, 0, 0, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float32, numpy=0.5>"
      ]
     },
     "execution_count": 109,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision.result()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Variable 'true_positives:0' shape=(1,) dtype=float32, numpy=array([4.], dtype=float32)>,\n",
       " <tf.Variable 'false_positives:0' shape=(1,) dtype=float32, numpy=array([4.], dtype=float32)>]"
      ]
     },
     "execution_count": 110,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision.variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [],
   "source": [
    "precision.reset_states()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Creating a streaming metric:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HuberMetric(keras.metrics.Metric):\n",
    "    def __init__(self, threshold=1.0, **kwargs):\n",
    "        super().__init__(**kwargs) # handles base args (e.g., dtype)\n",
    "        self.threshold = threshold\n",
    "        #self.huber_fn = create_huber(threshold) # TODO: investigate why this fails\n",
    "        self.total = self.add_weight(\"total\", initializer=\"zeros\")\n",
    "        self.count = self.add_weight(\"count\", initializer=\"zeros\")\n",
    "    def huber_fn(self, y_true, y_pred): # workaround\n",
    "        error = y_true - y_pred\n",
    "        is_small_error = tf.abs(error) < self.threshold\n",
    "        squared_loss = tf.square(error) / 2\n",
    "        linear_loss  = self.threshold * tf.abs(error) - self.threshold**2 / 2\n",
    "        return tf.where(is_small_error, squared_loss, linear_loss)\n",
    "    def update_state(self, y_true, y_pred, sample_weight=None):\n",
    "        metric = self.huber_fn(y_true, y_pred)\n",
    "        self.total.assign_add(tf.reduce_sum(metric))\n",
    "        self.count.assign_add(tf.cast(tf.size(y_true), tf.float32))\n",
    "    def result(self):\n",
    "        return self.total / self.count\n",
    "    def get_config(self):\n",
    "        base_config = super().get_config()\n",
    "        return {**base_config, \"threshold\": self.threshold}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Warning**: when running the following cell, if you get autograph warnings such as `WARNING:tensorflow:AutoGraph could not transform [...] and will run it as-is`, then please install version 0.2.2 of the gast library (e.g., by running `!pip install gast==0.2.2`), then restart the kernel and run this notebook again from the beginning (see [autograph issue #1](https://github.com/tensorflow/autograph/issues/1) for more details):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float32, numpy=14.0>"
      ]
     },
     "execution_count": 113,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m = HuberMetric(2.)\n",
    "\n",
    "# total = 2 * |10 - 2| - 2²/2 = 14\n",
    "# count = 1\n",
    "# result = 14 / 1 = 14\n",
    "m(tf.constant([[2.]]), tf.constant([[10.]])) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float32, numpy=7.0>"
      ]
     },
     "execution_count": 114,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# total = total + (|1 - 0|² / 2) + (2 * |9.25 - 5| - 2² / 2) = 14 + 7 = 21\n",
    "# count = count + 2 = 3\n",
    "# result = total / count = 21 / 3 = 7\n",
    "m(tf.constant([[0.], [5.]]), tf.constant([[1.], [9.25]]))\n",
    "\n",
    "m.result()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Variable 'total:0' shape=() dtype=float32, numpy=21.0>,\n",
       " <tf.Variable 'count:0' shape=() dtype=float32, numpy=3.0>]"
      ]
     },
     "execution_count": 115,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m.variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Variable 'total:0' shape=() dtype=float32, numpy=0.0>,\n",
       " <tf.Variable 'count:0' shape=() dtype=float32, numpy=0.0>]"
      ]
     },
     "execution_count": 116,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m.reset_states()\n",
    "m.variables"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's check that the `HuberMetric` class works well:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.models.Sequential([\n",
    "    keras.layers.Dense(30, activation=\"selu\", kernel_initializer=\"lecun_normal\",\n",
    "                       input_shape=input_shape),\n",
    "    keras.layers.Dense(1),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(loss=create_huber(2.0), optimizer=\"nadam\", metrics=[HuberMetric(2.0)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples\n",
      "Epoch 1/2\n",
      "11610/11610 [==============================] - 1s 55us/sample - loss: 0.8646 - huber_metric: 0.8646\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 25us/sample - loss: 0.2564 - huber_metric: 0.2564\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f9fa8dda290>"
      ]
     },
     "execution_count": 120,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X_train_scaled.astype(np.float32), y_train.astype(np.float32), epochs=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save(\"my_model_with_a_custom_metric.h5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "metadata": {},
   "outputs": [],
   "source": [
    "#model = keras.models.load_model(\"my_model_with_a_custom_metric.h5\",           # TODO: check PR #25956\n",
    "#                                custom_objects={\"huber_fn\": create_huber(2.0),\n",
    "#                                                \"HuberMetric\": HuberMetric})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples\n",
      "Epoch 1/2\n",
      "11610/11610 [==============================] - 0s 25us/sample - loss: 0.2349 - huber_metric: 0.2349\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 24us/sample - loss: 0.2286 - huber_metric: 0.2286\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f9fe0547e90>"
      ]
     },
     "execution_count": 123,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X_train_scaled.astype(np.float32), y_train.astype(np.float32), epochs=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Warning**: In TF 2.2, tf.keras adds an extra first metric in `model.metrics` at position 0 (see [TF issue #38150](https://github.com/tensorflow/tensorflow/issues/38150)). This forces us to use `model.metrics[-1]` rather than `model.metrics[0]` to access the `HuberMetric`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.0"
      ]
     },
     "execution_count": 124,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.metrics[-1].threshold"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Looks like it works fine! More simply, we could have created the class like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HuberMetric(keras.metrics.Mean):\n",
    "    def __init__(self, threshold=1.0, name='HuberMetric', dtype=None):\n",
    "        self.threshold = threshold\n",
    "        self.huber_fn = create_huber(threshold)\n",
    "        super().__init__(name=name, dtype=dtype)\n",
    "    def update_state(self, y_true, y_pred, sample_weight=None):\n",
    "        metric = self.huber_fn(y_true, y_pred)\n",
    "        super(HuberMetric, self).update_state(metric, sample_weight)\n",
    "    def get_config(self):\n",
    "        base_config = super().get_config()\n",
    "        return {**base_config, \"threshold\": self.threshold}        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This class handles shapes better, and it also supports sample weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.models.Sequential([\n",
    "    keras.layers.Dense(30, activation=\"selu\", kernel_initializer=\"lecun_normal\",\n",
    "                       input_shape=input_shape),\n",
    "    keras.layers.Dense(1),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(loss=keras.losses.Huber(2.0), optimizer=\"nadam\", weighted_metrics=[HuberMetric(2.0)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:sample_weight modes were coerced from\n",
      "  ...\n",
      "    to  \n",
      "  ['...']\n",
      "Train on 11610 samples\n",
      "Epoch 1/2\n",
      "11610/11610 [==============================] - 1s 70us/sample - loss: 0.4330 - HuberMetric: 0.8725\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 29us/sample - loss: 0.1286 - HuberMetric: 0.2592\n"
     ]
    }
   ],
   "source": [
    "sample_weight = np.random.rand(len(y_train))\n",
    "history = model.fit(X_train_scaled.astype(np.float32), y_train.astype(np.float32),\n",
    "                    epochs=2, sample_weight=sample_weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.43299613858249864, 0.4329960495905965)"
      ]
     },
     "execution_count": 130,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "history.history[\"loss\"][0], history.history[\"HuberMetric\"][0] * sample_weight.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save(\"my_model_with_a_custom_metric_v2.h5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "metadata": {},
   "outputs": [],
   "source": [
    "#model = keras.models.load_model(\"my_model_with_a_custom_metric_v2.h5\",        # TODO: check PR #25956\n",
    "#                                custom_objects={\"HuberMetric\": HuberMetric})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 133,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples\n",
      "Epoch 1/2\n",
      "11610/11610 [==============================] - 1s 57us/sample - loss: 0.2363 - HuberMetric: 0.2363\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 26us/sample - loss: 0.2274 - HuberMetric: 0.2274\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f9fa64e9dd0>"
      ]
     },
     "execution_count": 133,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X_train_scaled.astype(np.float32), y_train.astype(np.float32), epochs=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Warning**: In TF 2.2, tf.keras adds an extra first metric in `model.metrics` at position 0 (see [TF issue #38150](https://github.com/tensorflow/tensorflow/issues/38150)). This forces us to use `model.metrics[-1]` rather than `model.metrics[0]` to access the `HuberMetric`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.0"
      ]
     },
     "execution_count": 134,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.metrics[-1].threshold"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Custom Layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [],
   "source": [
    "exponential_layer = keras.layers.Lambda(lambda x: tf.exp(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.36787945, 1.        , 2.7182817 ], dtype=float32)>"
      ]
     },
     "execution_count": 136,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "exponential_layer([-1., 0., 1.])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Adding an exponential layer at the output of a regression model can be useful if the values to predict are positive and with very different scales (e.g., 0.001, 10., 10000):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples, validate on 3870 samples\n",
      "Epoch 1/5\n",
      "11610/11610 [==============================] - 1s 60us/sample - loss: nan - val_loss: nan\n",
      "Epoch 2/5\n",
      "11610/11610 [==============================] - 0s 30us/sample - loss: nan - val_loss: nan\n",
      "Epoch 3/5\n",
      "11610/11610 [==============================] - 0s 31us/sample - loss: nan - val_loss: nan\n",
      "Epoch 4/5\n",
      "11610/11610 [==============================] - 0s 33us/sample - loss: nan - val_loss: nan\n",
      "Epoch 5/5\n",
      "11610/11610 [==============================] - 0s 32us/sample - loss: nan - val_loss: nan\n",
      "5160/5160 [==============================] - 0s 15us/sample - loss: nan\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "nan"
      ]
     },
     "execution_count": 138,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = keras.models.Sequential([\n",
    "    keras.layers.Dense(30, activation=\"relu\", input_shape=input_shape),\n",
    "    keras.layers.Dense(1),\n",
    "    exponential_layer\n",
    "])\n",
    "model.compile(loss=\"mse\", optimizer=\"nadam\")\n",
    "model.fit(X_train_scaled, y_train, epochs=5,\n",
    "          validation_data=(X_valid_scaled, y_valid))\n",
    "model.evaluate(X_test_scaled, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyDense(keras.layers.Layer):\n",
    "    def __init__(self, units, activation=None, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.units = units\n",
    "        self.activation = keras.activations.get(activation)\n",
    "\n",
    "    def build(self, batch_input_shape):\n",
    "        self.kernel = self.add_weight(\n",
    "            name=\"kernel\", shape=[batch_input_shape[-1], self.units],\n",
    "            initializer=\"glorot_normal\")\n",
    "        self.bias = self.add_weight(\n",
    "            name=\"bias\", shape=[self.units], initializer=\"zeros\")\n",
    "        super().build(batch_input_shape) # must be at the end\n",
    "\n",
    "    def call(self, X):\n",
    "        return self.activation(X @ self.kernel + self.bias)\n",
    "\n",
    "    def compute_output_shape(self, batch_input_shape):\n",
    "        return tf.TensorShape(batch_input_shape.as_list()[:-1] + [self.units])\n",
    "\n",
    "    def get_config(self):\n",
    "        base_config = super().get_config()\n",
    "        return {**base_config, \"units\": self.units,\n",
    "                \"activation\": keras.activations.serialize(self.activation)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.models.Sequential([\n",
    "    MyDense(30, activation=\"relu\", input_shape=input_shape),\n",
    "    MyDense(1)\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples, validate on 3870 samples\n",
      "Epoch 1/2\n",
      "11610/11610 [==============================] - 1s 60us/sample - loss: 2.1425 - val_loss: 1.2779\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 32us/sample - loss: 0.6367 - val_loss: 0.5267\n",
      "5160/5160 [==============================] - 0s 15us/sample - loss: 0.5359\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.5359284550644631"
      ]
     },
     "execution_count": 142,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.compile(loss=\"mse\", optimizer=\"nadam\")\n",
    "model.fit(X_train_scaled, y_train, epochs=2,\n",
    "          validation_data=(X_valid_scaled, y_valid))\n",
    "model.evaluate(X_test_scaled, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save(\"my_model_with_a_custom_layer.h5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.models.load_model(\"my_model_with_a_custom_layer.h5\",\n",
    "                                custom_objects={\"MyDense\": MyDense})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyMultiLayer(keras.layers.Layer):\n",
    "    def call(self, X):\n",
    "        X1, X2 = X\n",
    "        return X1 + X2, X1 * X2\n",
    "\n",
    "    def compute_output_shape(self, batch_input_shape):\n",
    "        batch_input_shape1, batch_input_shape2 = batch_input_shape\n",
    "        return [batch_input_shape1, batch_input_shape2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs1 = keras.layers.Input(shape=[2])\n",
    "inputs2 = keras.layers.Input(shape=[2])\n",
    "outputs1, outputs2 = MyMultiLayer()((inputs1, inputs2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create a layer with a different behavior during training and testing:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AddGaussianNoise(keras.layers.Layer):\n",
    "    def __init__(self, stddev, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.stddev = stddev\n",
    "\n",
    "    def call(self, X, training=None):\n",
    "        if training:\n",
    "            noise = tf.random.normal(tf.shape(X), stddev=self.stddev)\n",
    "            return X + noise\n",
    "        else:\n",
    "            return X\n",
    "\n",
    "    def compute_output_shape(self, batch_input_shape):\n",
    "        return batch_input_shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples, validate on 3870 samples\n",
      "Epoch 1/2\n",
      "11610/11610 [==============================] - 1s 61us/sample - loss: 0.4776 - val_loss: 0.5223\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 32us/sample - loss: 0.4213 - val_loss: 0.3807\n",
      "5160/5160 [==============================] - 0s 16us/sample - loss: 0.3977\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.3977465312610301"
      ]
     },
     "execution_count": 149,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.compile(loss=\"mse\", optimizer=\"nadam\")\n",
    "model.fit(X_train_scaled, y_train, epochs=2,\n",
    "          validation_data=(X_valid_scaled, y_valid))\n",
    "model.evaluate(X_test_scaled, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Custom Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_new_scaled = X_test_scaled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 151,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResidualBlock(keras.layers.Layer):\n",
    "    def __init__(self, n_layers, n_neurons, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.hidden = [keras.layers.Dense(n_neurons, activation=\"elu\",\n",
    "                                          kernel_initializer=\"he_normal\")\n",
    "                       for _ in range(n_layers)]\n",
    "\n",
    "    def call(self, inputs):\n",
    "        Z = inputs\n",
    "        for layer in self.hidden:\n",
    "            Z = layer(Z)\n",
    "        return inputs + Z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResidualRegressor(keras.models.Model):\n",
    "    def __init__(self, output_dim, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.hidden1 = keras.layers.Dense(30, activation=\"elu\",\n",
    "                                          kernel_initializer=\"he_normal\")\n",
    "        self.block1 = ResidualBlock(2, 30)\n",
    "        self.block2 = ResidualBlock(2, 30)\n",
    "        self.out = keras.layers.Dense(output_dim)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        Z = self.hidden1(inputs)\n",
    "        for _ in range(1 + 3):\n",
    "            Z = self.block1(Z)\n",
    "        Z = self.block2(Z)\n",
    "        return self.out(Z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples\n",
      "Epoch 1/5\n",
      "11610/11610 [==============================] - 1s 104us/sample - loss: 11.4379\n",
      "Epoch 2/5\n",
      "11610/11610 [==============================] - 0s 41us/sample - loss: 1.7623\n",
      "Epoch 3/5\n",
      "11610/11610 [==============================] - 0s 42us/sample - loss: 1.4744\n",
      "Epoch 4/5\n",
      "11610/11610 [==============================] - 0s 41us/sample - loss: 0.6444\n",
      "Epoch 5/5\n",
      "11610/11610 [==============================] - 0s 41us/sample - loss: 0.5185\n",
      "5160/5160 [==============================] - 0s 32us/sample - loss: 0.6047\n"
     ]
    }
   ],
   "source": [
    "model = ResidualRegressor(1)\n",
    "model.compile(loss=\"mse\", optimizer=\"nadam\")\n",
    "history = model.fit(X_train_scaled, y_train, epochs=5)\n",
    "score = model.evaluate(X_test_scaled, y_test)\n",
    "y_pred = model.predict(X_new_scaled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /Users/ageron/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "If using Keras pass *_constraint arguments to layers.\n",
      "INFO:tensorflow:Assets written to: my_custom_model.ckpt/assets\n"
     ]
    }
   ],
   "source": [
    "model.save(\"my_custom_model.ckpt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.models.load_model(\"my_custom_model.ckpt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples\n",
      "Epoch 1/5\n",
      "11610/11610 [==============================] - 1s 123us/sample - loss: 0.6816\n",
      "Epoch 2/5\n",
      "11610/11610 [==============================] - 0s 42us/sample - loss: 0.5066\n",
      "Epoch 3/5\n",
      "11610/11610 [==============================] - 0s 42us/sample - loss: 0.6510\n",
      "Epoch 4/5\n",
      "11610/11610 [==============================] - 0s 43us/sample - loss: 0.6147\n",
      "Epoch 5/5\n",
      "11610/11610 [==============================] - 0s 42us/sample - loss: 0.9777\n"
     ]
    }
   ],
   "source": [
    "history = model.fit(X_train_scaled, y_train, epochs=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We could have defined the model using the sequential API instead:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "metadata": {},
   "outputs": [],
   "source": [
    "block1 = ResidualBlock(2, 30)\n",
    "model = keras.models.Sequential([\n",
    "    keras.layers.Dense(30, activation=\"elu\", kernel_initializer=\"he_normal\"),\n",
    "    block1, block1, block1, block1,\n",
    "    ResidualBlock(2, 30),\n",
    "    keras.layers.Dense(1)\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 160,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples\n",
      "Epoch 1/5\n",
      "11610/11610 [==============================] - 1s 88us/sample - loss: 0.9281\n",
      "Epoch 2/5\n",
      "11610/11610 [==============================] - 0s 38us/sample - loss: 0.4480\n",
      "Epoch 3/5\n",
      "11610/11610 [==============================] - 0s 37us/sample - loss: 0.7364\n",
      "Epoch 4/5\n",
      "11610/11610 [==============================] - 0s 37us/sample - loss: 0.4003\n",
      "Epoch 5/5\n",
      "11610/11610 [==============================] - 0s 39us/sample - loss: 0.5058\n",
      "5160/5160 [==============================] - 0s 29us/sample - loss: 0.4289\n"
     ]
    }
   ],
   "source": [
    "model.compile(loss=\"mse\", optimizer=\"nadam\")\n",
    "history = model.fit(X_train_scaled, y_train, epochs=5)\n",
    "score = model.evaluate(X_test_scaled, y_test)\n",
    "y_pred = model.predict(X_new_scaled)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Losses and Metrics Based on Model Internals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 161,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReconstructingRegressor(keras.models.Model):\n",
    "    def __init__(self, output_dim, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.hidden = [keras.layers.Dense(30, activation=\"selu\",\n",
    "                                          kernel_initializer=\"lecun_normal\")\n",
    "                       for _ in range(5)]\n",
    "        self.out = keras.layers.Dense(output_dim)\n",
    "        # TODO: check https://github.com/tensorflow/tensorflow/issues/26260\n",
    "        #self.reconstruction_mean = keras.metrics.Mean(name=\"reconstruction_error\")\n",
    "\n",
    "    def build(self, batch_input_shape):\n",
    "        n_inputs = batch_input_shape[-1]\n",
    "        self.reconstruct = keras.layers.Dense(n_inputs)\n",
    "        super().build(batch_input_shape)\n",
    "\n",
    "    def call(self, inputs, training=None):\n",
    "        Z = inputs\n",
    "        for layer in self.hidden:\n",
    "            Z = layer(Z)\n",
    "        reconstruction = self.reconstruct(Z)\n",
    "        recon_loss = tf.reduce_mean(tf.square(reconstruction - inputs))\n",
    "        self.add_loss(0.05 * recon_loss)\n",
    "        #if training:\n",
    "        #    result = self.reconstruction_mean(recon_loss)\n",
    "        #    self.add_metric(result)\n",
    "        return self.out(Z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 162,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 163,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples\n",
      "Epoch 1/2\n",
      "11610/11610 [==============================] - 1s 98us/sample - loss: 0.7647\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 40us/sample - loss: 0.4112\n"
     ]
    }
   ],
   "source": [
    "model = ReconstructingRegressor(1)\n",
    "model.compile(loss=\"mse\", optimizer=\"nadam\")\n",
    "history = model.fit(X_train_scaled, y_train, epochs=2)\n",
    "y_pred = model.predict(X_test_scaled)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Computing Gradients with Autodiff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 164,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(w1, w2):\n",
    "    return 3 * w1 ** 2 + 2 * w1 * w2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "36.000003007075065"
      ]
     },
     "execution_count": 165,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "w1, w2 = 5, 3\n",
    "eps = 1e-6\n",
    "(f(w1 + eps, w2) - f(w1, w2)) / eps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10.000000003174137"
      ]
     },
     "execution_count": 166,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(f(w1, w2 + eps) - f(w1, w2)) / eps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "metadata": {},
   "outputs": [],
   "source": [
    "w1, w2 = tf.Variable(5.), tf.Variable(3.)\n",
    "with tf.GradientTape() as tape:\n",
    "    z = f(w1, w2)\n",
    "\n",
    "gradients = tape.gradient(z, [w1, w2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 168,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Tensor: shape=(), dtype=float32, numpy=36.0>,\n",
       " <tf.Tensor: shape=(), dtype=float32, numpy=10.0>]"
      ]
     },
     "execution_count": 168,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 169,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GradientTape.gradient can only be called once on non-persistent tapes.\n"
     ]
    }
   ],
   "source": [
    "with tf.GradientTape() as tape:\n",
    "    z = f(w1, w2)\n",
    "\n",
    "dz_dw1 = tape.gradient(z, w1)\n",
    "try:\n",
    "    dz_dw2 = tape.gradient(z, w2)\n",
    "except RuntimeError as ex:\n",
    "    print(ex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 170,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.GradientTape(persistent=True) as tape:\n",
    "    z = f(w1, w2)\n",
    "\n",
    "dz_dw1 = tape.gradient(z, w1)\n",
    "dz_dw2 = tape.gradient(z, w2) # works now!\n",
    "del tape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 171,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<tf.Tensor: shape=(), dtype=float32, numpy=36.0>,\n",
       " <tf.Tensor: shape=(), dtype=float32, numpy=10.0>)"
      ]
     },
     "execution_count": 171,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dz_dw1, dz_dw2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 172,
   "metadata": {},
   "outputs": [],
   "source": [
    "c1, c2 = tf.constant(5.), tf.constant(3.)\n",
    "with tf.GradientTape() as tape:\n",
    "    z = f(c1, c2)\n",
    "\n",
    "gradients = tape.gradient(z, [c1, c2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 173,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[None, None]"
      ]
     },
     "execution_count": 173,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 174,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.GradientTape() as tape:\n",
    "    tape.watch(c1)\n",
    "    tape.watch(c2)\n",
    "    z = f(c1, c2)\n",
    "\n",
    "gradients = tape.gradient(z, [c1, c2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Tensor: shape=(), dtype=float32, numpy=36.0>,\n",
       " <tf.Tensor: shape=(), dtype=float32, numpy=10.0>]"
      ]
     },
     "execution_count": 175,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Tensor: shape=(), dtype=float32, numpy=136.0>,\n",
       " <tf.Tensor: shape=(), dtype=float32, numpy=30.0>]"
      ]
     },
     "execution_count": 176,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with tf.GradientTape() as tape:\n",
    "    z1 = f(w1, w2 + 2.)\n",
    "    z2 = f(w1, w2 + 5.)\n",
    "    z3 = f(w1, w2 + 7.)\n",
    "\n",
    "tape.gradient([z1, z2, z3], [w1, w2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.GradientTape(persistent=True) as tape:\n",
    "    z1 = f(w1, w2 + 2.)\n",
    "    z2 = f(w1, w2 + 5.)\n",
    "    z3 = f(w1, w2 + 7.)\n",
    "\n",
    "tf.reduce_sum(tf.stack([tape.gradient(z, [w1, w2]) for z in (z1, z2, z3)]), axis=0)\n",
    "del tape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.GradientTape(persistent=True) as hessian_tape:\n",
    "    with tf.GradientTape() as jacobian_tape:\n",
    "        z = f(w1, w2)\n",
    "    jacobians = jacobian_tape.gradient(z, [w1, w2])\n",
    "hessians = [hessian_tape.gradient(jacobian, [w1, w2])\n",
    "            for jacobian in jacobians]\n",
    "del hessian_tape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Tensor: shape=(), dtype=float32, numpy=36.0>,\n",
       " <tf.Tensor: shape=(), dtype=float32, numpy=10.0>]"
      ]
     },
     "execution_count": 179,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jacobians"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 180,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[<tf.Tensor: shape=(), dtype=float32, numpy=6.0>,\n",
       "  <tf.Tensor: shape=(), dtype=float32, numpy=2.0>],\n",
       " [<tf.Tensor: shape=(), dtype=float32, numpy=2.0>, None]]"
      ]
     },
     "execution_count": 180,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hessians"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 181,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Tensor: shape=(), dtype=float32, numpy=30.0>, None]"
      ]
     },
     "execution_count": 181,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def f(w1, w2):\n",
    "    return 3 * w1 ** 2 + tf.stop_gradient(2 * w1 * w2)\n",
    "\n",
    "with tf.GradientTape() as tape:\n",
    "    z = f(w1, w2)\n",
    "\n",
    "tape.gradient(z, [w1, w2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 182,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Tensor: shape=(), dtype=float32, numpy=nan>]"
      ]
     },
     "execution_count": 182,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = tf.Variable(100.)\n",
    "with tf.GradientTape() as tape:\n",
    "    z = my_softplus(x)\n",
    "\n",
    "tape.gradient(z, [x])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 183,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float32, numpy=30.0>"
      ]
     },
     "execution_count": 183,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.math.log(tf.exp(tf.constant(30., dtype=tf.float32)) + 1.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 184,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Tensor: shape=(1,), dtype=float32, numpy=array([nan], dtype=float32)>]"
      ]
     },
     "execution_count": 184,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = tf.Variable([100.])\n",
    "with tf.GradientTape() as tape:\n",
    "    z = my_softplus(x)\n",
    "\n",
    "tape.gradient(z, [x])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 185,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.custom_gradient\n",
    "def my_better_softplus(z):\n",
    "    exp = tf.exp(z)\n",
    "    def my_softplus_gradients(grad):\n",
    "        return grad / (1 + 1 / exp)\n",
    "    return tf.math.log(exp + 1), my_softplus_gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 186,
   "metadata": {},
   "outputs": [],
   "source": [
    "def my_better_softplus(z):\n",
    "    return tf.where(z > 30., z, tf.math.log(tf.exp(z) + 1.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 187,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<tf.Tensor: shape=(1,), dtype=float32, numpy=array([1000.], dtype=float32)>,\n",
       " [<tf.Tensor: shape=(1,), dtype=float32, numpy=array([nan], dtype=float32)>])"
      ]
     },
     "execution_count": 187,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = tf.Variable([1000.])\n",
    "with tf.GradientTape() as tape:\n",
    "    z = my_better_softplus(x)\n",
    "\n",
    "z, tape.gradient(z, [x])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Computing Gradients Using Autodiff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 188,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 189,
   "metadata": {},
   "outputs": [],
   "source": [
    "l2_reg = keras.regularizers.l2(0.05)\n",
    "model = keras.models.Sequential([\n",
    "    keras.layers.Dense(30, activation=\"elu\", kernel_initializer=\"he_normal\",\n",
    "                       kernel_regularizer=l2_reg),\n",
    "    keras.layers.Dense(1, kernel_regularizer=l2_reg)\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 190,
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_batch(X, y, batch_size=32):\n",
    "    idx = np.random.randint(len(X), size=batch_size)\n",
    "    return X[idx], y[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 191,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_status_bar(iteration, total, loss, metrics=None):\n",
    "    metrics = \" - \".join([\"{}: {:.4f}\".format(m.name, m.result())\n",
    "                         for m in [loss] + (metrics or [])])\n",
    "    end = \"\" if iteration < total else \"\\n\"\n",
    "    print(\"\\r{}/{} - \".format(iteration, total) + metrics,\n",
    "          end=end)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 192,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "50/50 - loss: 0.0900 - mean_square: 858.5000\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "mean_loss = keras.metrics.Mean(name=\"loss\")\n",
    "mean_square = keras.metrics.Mean(name=\"mean_square\")\n",
    "for i in range(1, 50 + 1):\n",
    "    loss = 1 / i\n",
    "    mean_loss(loss)\n",
    "    mean_square(i ** 2)\n",
    "    print_status_bar(i, 50, mean_loss, [mean_square])\n",
    "    time.sleep(0.05)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A fancier version with a progress bar:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 193,
   "metadata": {},
   "outputs": [],
   "source": [
    "def progress_bar(iteration, total, size=30):\n",
    "    running = iteration < total\n",
    "    c = \">\" if running else \"=\"\n",
    "    p = (size - 1) * iteration // total\n",
    "    fmt = \"{{:-{}d}}/{{}} [{{}}]\".format(len(str(total)))\n",
    "    params = [iteration, total, \"=\" * p + c + \".\" * (size - p - 1)]\n",
    "    return fmt.format(*params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 194,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' 3500/10000 [=>....]'"
      ]
     },
     "execution_count": 194,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "progress_bar(3500, 10000, size=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 195,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_status_bar(iteration, total, loss, metrics=None, size=30):\n",
    "    metrics = \" - \".join([\"{}: {:.4f}\".format(m.name, m.result())\n",
    "                         for m in [loss] + (metrics or [])])\n",
    "    end = \"\" if iteration < total else \"\\n\"\n",
    "    print(\"\\r{} - {}\".format(progress_bar(iteration, total), metrics), end=end)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 196,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "50/50 [==============================] - loss: 0.0900 - mean_square: 858.5000\n"
     ]
    }
   ],
   "source": [
    "mean_loss = keras.metrics.Mean(name=\"loss\")\n",
    "mean_square = keras.metrics.Mean(name=\"mean_square\")\n",
    "for i in range(1, 50 + 1):\n",
    "    loss = 1 / i\n",
    "    mean_loss(loss)\n",
    "    mean_square(i ** 2)\n",
    "    print_status_bar(i, 50, mean_loss, [mean_square])\n",
    "    time.sleep(0.05)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 197,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 198,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_epochs = 5\n",
    "batch_size = 32\n",
    "n_steps = len(X_train) // batch_size\n",
    "optimizer = keras.optimizers.Nadam(lr=0.01)\n",
    "loss_fn = keras.losses.mean_squared_error\n",
    "mean_loss = keras.metrics.Mean()\n",
    "metrics = [keras.metrics.MeanAbsoluteError()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 199,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "WARNING:tensorflow:Layer dense is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.\n",
      "\n",
      "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
      "\n",
      "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
      "\n",
      "11610/11610 [==============================] - mean: 1.3955 - mean_absolute_error: 0.5722\n",
      "Epoch 2/5\n",
      "11610/11610 [==============================] - mean: 0.6774 - mean_absolute_error: 0.5280\n",
      "Epoch 3/5\n",
      "11610/11610 [==============================] - mean: 0.6351 - mean_absolute_error: 0.5177\n",
      "Epoch 4/5\n",
      "11610/11610 [==============================] - mean: 0.6384 - mean_absolute_error: 0.5181\n",
      "Epoch 5/5\n",
      "11610/11610 [==============================] - mean: 0.6440 - mean_absolute_error: 0.5222\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(1, n_epochs + 1):\n",
    "    print(\"Epoch {}/{}\".format(epoch, n_epochs))\n",
    "    for step in range(1, n_steps + 1):\n",
    "        X_batch, y_batch = random_batch(X_train_scaled, y_train)\n",
    "        with tf.GradientTape() as tape:\n",
    "            y_pred = model(X_batch)\n",
    "            main_loss = tf.reduce_mean(loss_fn(y_batch, y_pred))\n",
    "            loss = tf.add_n([main_loss] + model.losses)\n",
    "        gradients = tape.gradient(loss, model.trainable_variables)\n",
    "        optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "        for variable in model.variables:\n",
    "            if variable.constraint is not None:\n",
    "                variable.assign(variable.constraint(variable))\n",
    "        mean_loss(loss)\n",
    "        for metric in metrics:\n",
    "            metric(y_batch, y_pred)\n",
    "        print_status_bar(step * batch_size, len(y_train), mean_loss, metrics)\n",
    "    print_status_bar(len(y_train), len(y_train), mean_loss, metrics)\n",
    "    for metric in [mean_loss] + metrics:\n",
    "        metric.reset_states()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 200,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6cd1de57d3e5446ca07bed7a6b4b0a6e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='All epochs', max=5.0, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e50e4dbb3426481db4c12e54360fadd3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 1/5', max=362.0, style=ProgressStyle(description_wi…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b695174606604686b3875c08a6dabebe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 2/5', max=362.0, style=ProgressStyle(description_wi…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3f4ac210b9a94ece85faef23bd4ebd2c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 3/5', max=362.0, style=ProgressStyle(description_wi…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d90b821dd6ad41b9977c9796a3e58962",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 4/5', max=362.0, style=ProgressStyle(description_wi…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f7f6c4454bef41b581b4320751ca1a46",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 5/5', max=362.0, style=ProgressStyle(description_wi…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    from tqdm.notebook import trange\n",
    "    from collections import OrderedDict\n",
    "    with trange(1, n_epochs + 1, desc=\"All epochs\") as epochs:\n",
    "        for epoch in epochs:\n",
    "            with trange(1, n_steps + 1, desc=\"Epoch {}/{}\".format(epoch, n_epochs)) as steps:\n",
    "                for step in steps:\n",
    "                    X_batch, y_batch = random_batch(X_train_scaled, y_train)\n",
    "                    with tf.GradientTape() as tape:\n",
    "                        y_pred = model(X_batch)\n",
    "                        main_loss = tf.reduce_mean(loss_fn(y_batch, y_pred))\n",
    "                        loss = tf.add_n([main_loss] + model.losses)\n",
    "                    gradients = tape.gradient(loss, model.trainable_variables)\n",
    "                    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "                    for variable in model.variables:\n",
    "                        if variable.constraint is not None:\n",
    "                            variable.assign(variable.constraint(variable))                    \n",
    "                    status = OrderedDict()\n",
    "                    mean_loss(loss)\n",
    "                    status[\"loss\"] = mean_loss.result().numpy()\n",
    "                    for metric in metrics:\n",
    "                        metric(y_batch, y_pred)\n",
    "                        status[metric.name] = metric.result().numpy()\n",
    "                    steps.set_postfix(status)\n",
    "            for metric in [mean_loss] + metrics:\n",
    "                metric.reset_states()\n",
    "except ImportError as ex:\n",
    "    print(\"To run this cell, please install tqdm, ipywidgets and restart Jupyter\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TensorFlow Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 201,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cube(x):\n",
    "    return x ** 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 202,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8"
      ]
     },
     "execution_count": 202,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cube(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 203,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float32, numpy=8.0>"
      ]
     },
     "execution_count": 203,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cube(tf.constant(2.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 204,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.eager.def_function.Function at 0x7f9fa612af90>"
      ]
     },
     "execution_count": 204,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf_cube = tf.function(cube)\n",
    "tf_cube"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 205,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=int32, numpy=8>"
      ]
     },
     "execution_count": 205,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf_cube(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 206,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float32, numpy=8.0>"
      ]
     },
     "execution_count": 206,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf_cube(tf.constant(2.0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TF Functions and Concrete Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 207,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.framework.func_graph.FuncGraph at 0x7f9f90165690>"
      ]
     },
     "execution_count": 207,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "concrete_function = tf_cube.get_concrete_function(tf.constant(2.0))\n",
    "concrete_function.graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 208,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float32, numpy=8.0>"
      ]
     },
     "execution_count": 208,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "concrete_function(tf.constant(2.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 209,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 209,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "concrete_function is tf_cube.get_concrete_function(tf.constant(2.0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exploring Function Definitions and Graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 210,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.framework.func_graph.FuncGraph at 0x7f9f90165690>"
      ]
     },
     "execution_count": 210,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "concrete_function.graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 211,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Operation 'x' type=Placeholder>,\n",
       " <tf.Operation 'pow/y' type=Const>,\n",
       " <tf.Operation 'pow' type=Pow>,\n",
       " <tf.Operation 'Identity' type=Identity>]"
      ]
     },
     "execution_count": 211,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ops = concrete_function.graph.get_operations()\n",
    "ops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 212,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Tensor 'x:0' shape=() dtype=float32>,\n",
       " <tf.Tensor 'pow/y:0' shape=() dtype=float32>]"
      ]
     },
     "execution_count": 212,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pow_op = ops[2]\n",
    "list(pow_op.inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 213,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Tensor 'pow:0' shape=() dtype=float32>]"
      ]
     },
     "execution_count": 213,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pow_op.outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 214,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Operation 'x' type=Placeholder>"
      ]
     },
     "execution_count": 214,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "concrete_function.graph.get_operation_by_name('x')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 215,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor 'Identity:0' shape=() dtype=float32>"
      ]
     },
     "execution_count": 215,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "concrete_function.graph.get_tensor_by_name('Identity:0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 216,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "name: \"__inference_cube_999576\"\n",
       "input_arg {\n",
       "  name: \"x\"\n",
       "  type: DT_FLOAT\n",
       "}\n",
       "output_arg {\n",
       "  name: \"identity\"\n",
       "  type: DT_FLOAT\n",
       "}"
      ]
     },
     "execution_count": 216,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "concrete_function.function_def.signature"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### How TF Functions Trace Python Functions to Extract Their Computation Graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 217,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def tf_cube(x):\n",
    "    print(\"print:\", x)\n",
    "    return x ** 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 218,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "print: Tensor(\"x:0\", shape=(), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "result = tf_cube(tf.constant(2.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 219,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float32, numpy=8.0>"
      ]
     },
     "execution_count": 219,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 220,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "print: 2\n",
      "print: 3\n",
      "print: Tensor(\"x:0\", shape=(1, 2), dtype=float32)\n",
      "print: Tensor(\"x:0\", shape=(2, 2), dtype=float32)\n",
      "WARNING:tensorflow:5 out of the last 5 calls to <function tf_cube at 0x7f9f9064c320> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n",
      "print: Tensor(\"x:0\", shape=(3, 2), dtype=float32)\n",
      "WARNING:tensorflow:6 out of the last 6 calls to <function tf_cube at 0x7f9f9064c320> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n"
     ]
    }
   ],
   "source": [
    "result = tf_cube(2)\n",
    "result = tf_cube(3)\n",
    "result = tf_cube(tf.constant([[1., 2.]])) # New shape: trace!\n",
    "result = tf_cube(tf.constant([[3., 4.], [5., 6.]])) # New shape: trace!\n",
    "result = tf_cube(tf.constant([[7., 8.], [9., 10.], [11., 12.]])) # New shape: trace!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is also possible to specify a particular input signature:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 221,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function(input_signature=[tf.TensorSpec([None, 28, 28], tf.float32)])\n",
    "def shrink(images):\n",
    "    print(\"Tracing\", images)\n",
    "    return images[:, ::2, ::2] # drop half the rows and columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 222,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 223,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tracing Tensor(\"images:0\", shape=(None, 28, 28), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "img_batch_1 = tf.random.uniform(shape=[100, 28, 28])\n",
    "img_batch_2 = tf.random.uniform(shape=[50, 28, 28])\n",
    "preprocessed_images = shrink(img_batch_1) # Traces the function.\n",
    "preprocessed_images = shrink(img_batch_2) # Reuses the same concrete function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 224,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Python inputs incompatible with input_signature:\n",
      "  inputs: (\n",
      "    tf.Tensor(\n",
      "[[[0.7413678  0.62854624]\n",
      "  [0.01738465 0.3431449 ]]\n",
      "\n",
      " [[0.51063764 0.3777541 ]\n",
      "  [0.07321596 0.02137029]]], shape=(2, 2, 2), dtype=float32))\n",
      "  input_signature: (\n",
      "    TensorSpec(shape=(None, 28, 28), dtype=tf.float32, name=None))\n"
     ]
    }
   ],
   "source": [
    "img_batch_3 = tf.random.uniform(shape=[2, 2, 2])\n",
    "try:\n",
    "    preprocessed_images = shrink(img_batch_3)  # rejects unexpected types or shapes\n",
    "except ValueError as ex:\n",
    "    print(ex)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using Autograph To Capture Control Flow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A \"static\" `for` loop using `range()`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 225,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def add_10(x):\n",
    "    for i in range(10):\n",
    "        x += 1\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 226,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=int32, numpy=15>"
      ]
     },
     "execution_count": 226,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "add_10(tf.constant(5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 227,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Operation 'x' type=Placeholder>,\n",
       " <tf.Operation 'add/y' type=Const>,\n",
       " <tf.Operation 'add' type=AddV2>,\n",
       " <tf.Operation 'add_1/y' type=Const>,\n",
       " <tf.Operation 'add_1' type=AddV2>,\n",
       " <tf.Operation 'add_2/y' type=Const>,\n",
       " <tf.Operation 'add_2' type=AddV2>,\n",
       " <tf.Operation 'add_3/y' type=Const>,\n",
       " <tf.Operation 'add_3' type=AddV2>,\n",
       " <tf.Operation 'add_4/y' type=Const>,\n",
       " <tf.Operation 'add_4' type=AddV2>,\n",
       " <tf.Operation 'add_5/y' type=Const>,\n",
       " <tf.Operation 'add_5' type=AddV2>,\n",
       " <tf.Operation 'add_6/y' type=Const>,\n",
       " <tf.Operation 'add_6' type=AddV2>,\n",
       " <tf.Operation 'add_7/y' type=Const>,\n",
       " <tf.Operation 'add_7' type=AddV2>,\n",
       " <tf.Operation 'add_8/y' type=Const>,\n",
       " <tf.Operation 'add_8' type=AddV2>,\n",
       " <tf.Operation 'add_9/y' type=Const>,\n",
       " <tf.Operation 'add_9' type=AddV2>,\n",
       " <tf.Operation 'Identity' type=Identity>]"
      ]
     },
     "execution_count": 227,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "add_10.get_concrete_function(tf.constant(5)).graph.get_operations()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A \"dynamic\" loop using `tf.while_loop()`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 228,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def add_10(x):\n",
    "    condition = lambda i, x: tf.less(i, 10)\n",
    "    body = lambda i, x: (tf.add(i, 1), tf.add(x, 1))\n",
    "    final_i, final_x = tf.while_loop(condition, body, [tf.constant(0), x])\n",
    "    return final_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 229,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=int32, numpy=15>"
      ]
     },
     "execution_count": 229,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "add_10(tf.constant(5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 230,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Operation 'x' type=Placeholder>,\n",
       " <tf.Operation 'Const' type=Const>,\n",
       " <tf.Operation 'while/maximum_iterations' type=Const>,\n",
       " <tf.Operation 'while/loop_counter' type=Const>,\n",
       " <tf.Operation 'while' type=StatelessWhile>,\n",
       " <tf.Operation 'Identity' type=Identity>]"
      ]
     },
     "execution_count": 230,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "add_10.get_concrete_function(tf.constant(5)).graph.get_operations()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A \"dynamic\" `for` loop using `tf.range()` (captured by autograph):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 231,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def add_10(x):\n",
    "    for i in tf.range(10):\n",
    "        x = x + 1\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 232,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Operation 'x' type=Placeholder>,\n",
       " <tf.Operation 'range/start' type=Const>,\n",
       " <tf.Operation 'range/limit' type=Const>,\n",
       " <tf.Operation 'range/delta' type=Const>,\n",
       " <tf.Operation 'range' type=Range>,\n",
       " <tf.Operation 'while/maximum_iterations' type=Const>,\n",
       " <tf.Operation 'while/loop_counter' type=Const>,\n",
       " <tf.Operation 'while' type=StatelessWhile>,\n",
       " <tf.Operation 'Identity' type=Identity>]"
      ]
     },
     "execution_count": 232,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "add_10.get_concrete_function(tf.constant(0)).graph.get_operations()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Handling Variables and Other Resources in TF Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 233,
   "metadata": {},
   "outputs": [],
   "source": [
    "counter = tf.Variable(0)\n",
    "\n",
    "@tf.function\n",
    "def increment(counter, c=1):\n",
    "    return counter.assign_add(c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 234,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=int32, numpy=2>"
      ]
     },
     "execution_count": 234,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "increment(counter)\n",
    "increment(counter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 235,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "name: \"counter\"\n",
       "type: DT_RESOURCE"
      ]
     },
     "execution_count": 235,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function_def = increment.get_concrete_function(counter).function_def\n",
    "function_def.signature.input_arg[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 236,
   "metadata": {},
   "outputs": [],
   "source": [
    "counter = tf.Variable(0)\n",
    "\n",
    "@tf.function\n",
    "def increment(c=1):\n",
    "    return counter.assign_add(c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 237,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=int32, numpy=2>"
      ]
     },
     "execution_count": 237,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "increment()\n",
    "increment()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 238,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "name: \"assignaddvariableop_resource\"\n",
       "type: DT_RESOURCE"
      ]
     },
     "execution_count": 238,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function_def = increment.get_concrete_function().function_def\n",
    "function_def.signature.input_arg[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 239,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Counter:\n",
    "    def __init__(self):\n",
    "        self.counter = tf.Variable(0)\n",
    "\n",
    "    @tf.function\n",
    "    def increment(self, c=1):\n",
    "        return self.counter.assign_add(c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 240,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=int32, numpy=2>"
      ]
     },
     "execution_count": 240,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c = Counter()\n",
    "c.increment()\n",
    "c.increment()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 241,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"def tf__add_10(x):\\n  do_return = False\\n  retval_ = ag__.UndefinedReturnValue()\\n  with ag__.FunctionScope('add_10', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:\\n\\n    def get_state():\\n      return ()\\n\\n    def set_state(_):\\n      pass\\n\\n    def loop_body(iterates, x):\\n      i = iterates\\n      x += 1\\n      return x,\\n    x, = ag__.for_stmt(ag__.converted_call(tf.range, (10,), None, fscope), None, loop_body, get_state, set_state, (x,), ('x',), ())\\n    do_return = True\\n    retval_ = fscope.mark_return_value(x)\\n  do_return,\\n  return ag__.retval(retval_)\\n\""
      ]
     },
     "execution_count": 241,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@tf.function\n",
    "def add_10(x):\n",
    "    for i in tf.range(10):\n",
    "        x += 1\n",
    "    return x\n",
    "\n",
    "tf.autograph.to_code(add_10.python_function)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 242,
   "metadata": {},
   "outputs": [],
   "source": [
    "def display_tf_code(func):\n",
    "    from IPython.display import display, Markdown\n",
    "    if hasattr(func, \"python_function\"):\n",
    "        func = func.python_function\n",
    "    code = tf.autograph.to_code(func)\n",
    "    display(Markdown('```python\\n{}\\n```'.format(code)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 243,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "```python\n",
       "def tf__add_10(x):\n",
       "  do_return = False\n",
       "  retval_ = ag__.UndefinedReturnValue()\n",
       "  with ag__.FunctionScope('add_10', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:\n",
       "\n",
       "    def get_state():\n",
       "      return ()\n",
       "\n",
       "    def set_state(_):\n",
       "      pass\n",
       "\n",
       "    def loop_body(iterates, x):\n",
       "      i = iterates\n",
       "      x += 1\n",
       "      return x,\n",
       "    x, = ag__.for_stmt(ag__.converted_call(tf.range, (10,), None, fscope), None, loop_body, get_state, set_state, (x,), ('x',), ())\n",
       "    do_return = True\n",
       "    retval_ = fscope.mark_return_value(x)\n",
       "  do_return,\n",
       "  return ag__.retval(retval_)\n",
       "\n",
       "```"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display_tf_code(add_10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using TF Functions with tf.keras (or Not)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, tf.keras will automatically convert your custom code into TF Functions, no need to use\n",
    "`tf.function()`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 244,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom loss function\n",
    "def my_mse(y_true, y_pred):\n",
    "    print(\"Tracing loss my_mse()\")\n",
    "    return tf.reduce_mean(tf.square(y_pred - y_true))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 245,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom metric function\n",
    "def my_mae(y_true, y_pred):\n",
    "    print(\"Tracing metric my_mae()\")\n",
    "    return tf.reduce_mean(tf.abs(y_pred - y_true))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 246,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom layer\n",
    "class MyDense(keras.layers.Layer):\n",
    "    def __init__(self, units, activation=None, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.units = units\n",
    "        self.activation = keras.activations.get(activation)\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        self.kernel = self.add_weight(name='kernel', \n",
    "                                      shape=(input_shape[1], self.units),\n",
    "                                      initializer='uniform',\n",
    "                                      trainable=True)\n",
    "        self.biases = self.add_weight(name='bias', \n",
    "                                      shape=(self.units,),\n",
    "                                      initializer='zeros',\n",
    "                                      trainable=True)\n",
    "        super().build(input_shape)\n",
    "\n",
    "    def call(self, X):\n",
    "        print(\"Tracing MyDense.call()\")\n",
    "        return self.activation(X @ self.kernel + self.biases)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 247,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 248,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom model\n",
    "class MyModel(keras.models.Model):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.hidden1 = MyDense(30, activation=\"relu\")\n",
    "        self.hidden2 = MyDense(30, activation=\"relu\")\n",
    "        self.output_ = MyDense(1)\n",
    "\n",
    "    def call(self, input):\n",
    "        print(\"Tracing MyModel.call()\")\n",
    "        hidden1 = self.hidden1(input)\n",
    "        hidden2 = self.hidden2(hidden1)\n",
    "        concat = keras.layers.concatenate([input, hidden2])\n",
    "        output = self.output_(concat)\n",
    "        return output\n",
    "\n",
    "model = MyModel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 249,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(loss=my_mse, optimizer=\"nadam\", metrics=[my_mae])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 250,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tracing MyModel.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing metric my_mae()\n",
      "Tracing loss my_mse()\n",
      "Train on 11610 samples, validate on 3870 samples\n",
      "Epoch 1/2\n",
      "Tracing MyModel.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing loss my_mse()\n",
      "Tracing metric my_mae()\n",
      "Tracing MyModel.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing loss my_mse()\n",
      "Tracing metric my_mae()\n",
      " 9888/11610 [========================>.....] - ETA: 0s - loss: 1.4160 - my_mae: 0.8284Tracing MyModel.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing loss my_mse()\n",
      "Tracing metric my_mae()\n",
      "11610/11610 [==============================] - 1s 74us/sample - loss: 1.2838 - my_mae: 0.7826 - val_loss: 0.4503 - val_my_mae: 0.4879\n",
      "Epoch 2/2\n",
      "11610/11610 [==============================] - 0s 36us/sample - loss: 0.4418 - my_mae: 0.4782 - val_loss: 0.7718 - val_my_mae: 0.4583\n",
      "5160/5160 [==============================] - 0s 16us/sample - loss: 0.4174 - my_mae: 0.4584\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.4173873734566592, 0.45841503]"
      ]
     },
     "execution_count": 250,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X_train_scaled, y_train, epochs=2,\n",
    "          validation_data=(X_valid_scaled, y_valid))\n",
    "model.evaluate(X_test_scaled, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can turn this off by creating the model with `dynamic=True` (or calling `super().__init__(dynamic=True, **kwargs)` in the model's constructor):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 251,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 252,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = MyModel(dynamic=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 253,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(loss=my_mse, optimizer=\"nadam\", metrics=[my_mae])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now the custom code will be called at each iteration. Let's fit, validate and evaluate with tiny datasets to avoid getting too much output:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 254,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tracing MyModel.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing loss my_mse()\n",
      "Tracing metric my_mae()\n",
      "Tracing MyModel.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing loss my_mse()\n",
      "Tracing metric my_mae()\n",
      "Tracing MyModel.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing loss my_mse()\n",
      "Tracing metric my_mae()\n",
      "Tracing MyModel.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing loss my_mse()\n",
      "Tracing metric my_mae()\n",
      "Tracing MyModel.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing loss my_mse()\n",
      "Tracing metric my_mae()\n",
      "Tracing MyModel.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing loss my_mse()\n",
      "Tracing metric my_mae()\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[5.50464653968811, 2.056248]"
      ]
     },
     "execution_count": 254,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X_train_scaled[:64], y_train[:64], epochs=1,\n",
    "          validation_data=(X_valid_scaled[:64], y_valid[:64]), verbose=0)\n",
    "model.evaluate(X_test_scaled[:64], y_test[:64], verbose=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alternatively, you can compile a model with `run_eagerly=True`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 255,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 256,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = MyModel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 257,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(loss=my_mse, optimizer=\"nadam\", metrics=[my_mae], run_eagerly=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 258,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tracing MyModel.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyModel.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing loss my_mse()\n",
      "Tracing metric my_mae()\n",
      "Tracing MyModel.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing loss my_mse()\n",
      "Tracing metric my_mae()\n",
      "Tracing MyModel.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing loss my_mse()\n",
      "Tracing metric my_mae()\n",
      "Tracing MyModel.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing loss my_mse()\n",
      "Tracing metric my_mae()\n",
      "Tracing MyModel.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing loss my_mse()\n",
      "Tracing metric my_mae()\n",
      "Tracing MyModel.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing MyDense.call()\n",
      "Tracing loss my_mse()\n",
      "Tracing metric my_mae()\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[5.50464653968811, 2.056248]"
      ]
     },
     "execution_count": 258,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X_train_scaled[:64], y_train[:64], epochs=1,\n",
    "          validation_data=(X_valid_scaled[:64], y_valid[:64]), verbose=0)\n",
    "model.evaluate(X_test_scaled[:64], y_test[:64], verbose=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Custom Optimizers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Defining custom optimizers is not very common, but in case you are one of the happy few who gets to write one, here is an example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 259,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyMomentumOptimizer(keras.optimizers.Optimizer):\n",
    "    \"\"\"Momentum optimizer built on the Keras `Optimizer` subclassing hooks.\n",
    "\n",
    "    For every model variable it keeps one \"momentum\" slot `m` and, at each step:\n",
    "        m   = momentum * m - (1 - momentum) * gradient\n",
    "        var = var + decayed_learning_rate * m\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, learning_rate=0.001, momentum=0.9, name=\"MyMomentumOptimizer\", **kwargs):\n",
    "        \"\"\"Call super().__init__() and register the hyperparameters with _set_hyper().\"\"\"\n",
    "        super().__init__(name, **kwargs)\n",
    "        # An `lr` keyword, when given, takes precedence over `learning_rate`.\n",
    "        initial_lr = kwargs.get(\"lr\", learning_rate)\n",
    "        self._set_hyper(\"learning_rate\", initial_lr)\n",
    "        self._set_hyper(\"decay\", self._initial_decay)\n",
    "        self._set_hyper(\"momentum\", momentum)\n",
    "\n",
    "    def _create_slots(self, var_list):\n",
    "        \"\"\"Create the optimizer variables (\"slots\" in TensorFlow terms).\n",
    "\n",
    "        Momentum optimization needs exactly one \"momentum\" slot per model variable.\n",
    "        \"\"\"\n",
    "        for model_var in var_list:\n",
    "            self.add_slot(model_var, \"momentum\")\n",
    "\n",
    "    @tf.function\n",
    "    def _resource_apply_dense(self, grad, var):\n",
    "        \"\"\"Update the momentum slot of `var`, then apply one optimization step to `var`.\"\"\"\n",
    "        dtype = var.dtype.base_dtype\n",
    "        decayed_lr = self._decayed_lr(dtype)  # learning rate with decay applied\n",
    "        beta = self._get_hyper(\"momentum\", dtype)\n",
    "        velocity = self.get_slot(var, \"momentum\")\n",
    "        velocity.assign(velocity * beta - (1. - beta) * grad)\n",
    "        var.assign_add(velocity * decayed_lr)\n",
    "\n",
    "    def _resource_apply_sparse(self, grad, var):\n",
    "        \"\"\"Sparse updates are not implemented for this optimizer.\"\"\"\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def get_config(self):\n",
    "        \"\"\"Return the base config extended with the serialized hyperparameters.\"\"\"\n",
    "        config = dict(super().get_config())\n",
    "        for hyper_name in (\"learning_rate\", \"decay\", \"momentum\"):\n",
    "            config[hyper_name] = self._serialize_hyperparameter(hyper_name)\n",
    "        return config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 260,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 261,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 11610 samples\n",
      "Epoch 1/5\n",
      "11610/11610 [==============================] - 0s 41us/sample - loss: 3.8096\n",
      "Epoch 2/5\n",
      "11610/11610 [==============================] - 0s 20us/sample - loss: 1.4864\n",
      "Epoch 3/5\n",
      "11610/11610 [==============================] - 0s 21us/sample - loss: 0.9173\n",
      "Epoch 4/5\n",
      "11610/11610 [==============================] - 0s 20us/sample - loss: 0.7593\n",
      "Epoch 5/5\n",
      "11610/11610 [==============================] - 0s 21us/sample - loss: 0.7036\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f9fe074fd90>"
      ]
     },
     "execution_count": 261,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = keras.models.Sequential([keras.layers.Dense(1, input_shape=[8])])\n",
    "model.compile(loss=\"mse\", optimizer=MyMomentumOptimizer())\n",
    "model.fit(X_train_scaled, y_train, epochs=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exercises"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. to 11.\n",
    "See Appendix A."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 12. Implement a custom layer that performs _Layer Normalization_\n",
    "_We will use this type of layer in Chapter 15 when using Recurrent Neural Networks._"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### a.\n",
    "_Exercise: The `build()` method should define two trainable weights *α* and *β*, both of shape `input_shape[-1:]` and data type `tf.float32`. *α* should be initialized with 1s, and *β* with 0s._"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Solution: see below."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### b.\n",
    "_Exercise: The `call()` method should compute the mean_ μ _and standard deviation_ σ _of each instance's features. For this, you can use `tf.nn.moments(inputs, axes=-1, keepdims=True)`, which returns the mean μ and the variance σ<sup>2</sup> of all instances (compute the square root of the variance to get the standard deviation). Then the function should compute and return *α*⊗(*X* - μ)/(σ + ε) + *β*, where ⊗ represents itemwise multiplication (`*`) and ε is a smoothing term (small constant to avoid division by zero, e.g., 0.001)._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 262,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LayerNormalization(keras.layers.Layer):\n",
    "    \"\"\"Normalize inputs over their last axis, then scale by `alpha` and shift by `beta`.\"\"\"\n",
    "\n",
    "    def __init__(self, eps=0.001, **kwargs):\n",
    "        \"\"\"Store the smoothing term `eps`; remaining keyword arguments go to the base layer.\"\"\"\n",
    "        super().__init__(**kwargs)\n",
    "        self.eps = eps\n",
    "\n",
    "    def build(self, batch_input_shape):\n",
    "        \"\"\"Create `alpha` (initialized to ones) and `beta` (zeros), shaped like the last input axis.\"\"\"\n",
    "        feature_shape = batch_input_shape[-1:]\n",
    "        self.alpha = self.add_weight(name=\"alpha\", shape=feature_shape, initializer=\"ones\")\n",
    "        self.beta = self.add_weight(name=\"beta\", shape=feature_shape, initializer=\"zeros\")\n",
    "        super().build(batch_input_shape)  # must be at the end\n",
    "\n",
    "    def call(self, X):\n",
    "        \"\"\"Return alpha * (X - mean) / sqrt(variance + eps) + beta, with moments over the last axis.\"\"\"\n",
    "        mean, variance = tf.nn.moments(X, axes=-1, keepdims=True)\n",
    "        centered = X - mean\n",
    "        # eps is added under the square root, so sqrt never receives exactly 0\n",
    "        # when the variance is 0.\n",
    "        stddev = tf.sqrt(variance + self.eps)\n",
    "        return self.alpha * centered / stddev + self.beta\n",
    "\n",
    "    def compute_output_shape(self, batch_input_shape):\n",
    "        \"\"\"The output shape is the input shape, unchanged.\"\"\"\n",
    "        return batch_input_shape\n",
    "\n",
    "    def get_config(self):\n",
    "        \"\"\"Return the base layer config plus `eps`.\"\"\"\n",
    "        config = dict(super().get_config())\n",
    "        config[\"eps\"] = self.eps\n",
    "        return config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that making _ε_ a hyperparameter (`eps`) was not compulsory. Also note that it's preferable to compute `tf.sqrt(variance + self.eps)` rather than `tf.sqrt(variance) + self.eps`. Indeed, the derivative of sqrt(z) is undefined when z=0, so training will bomb whenever the variance vector has at least one component equal to 0. Adding _ε_ within the square root guarantees that this will never happen."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### c.\n",
    "_Exercise: Ensure that your custom layer produces the same (or very nearly the same) output as the `keras.layers.LayerNormalization` layer._"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create one instance of each class, apply them to some data (e.g., the training set), and ensure that the difference is negligible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 263,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float32, numpy=5.6045884e-08>"
      ]
     },
     "execution_count": 263,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = X_train.astype(np.float32)\n",
    "\n",
    "custom_layer_norm = LayerNormalization()\n",
    "keras_layer_norm = keras.layers.LayerNormalization()\n",
    "\n",
    "tf.reduce_mean(keras.losses.mean_absolute_error(\n",
    "    keras_layer_norm(X), custom_layer_norm(X)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Yep, that's close enough. To be extra sure, let's make alpha and beta completely random and compare again:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 264,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float32, numpy=2.2921004e-08>"
      ]
     },
     "execution_count": 264,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "random_alpha = np.random.rand(X.shape[-1])\n",
    "random_beta = np.random.rand(X.shape[-1])\n",
    "\n",
    "custom_layer_norm.set_weights([random_alpha, random_beta])\n",
    "keras_layer_norm.set_weights([random_alpha, random_beta])\n",
    "\n",
    "tf.reduce_mean(keras.losses.mean_absolute_error(\n",
    "    keras_layer_norm(X), custom_layer_norm(X)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Still a negligible difference! Our custom layer works fine."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 13. Train a model using a custom training loop to tackle the Fashion MNIST dataset\n",
    "_The Fashion MNIST dataset was introduced in Chapter 10._"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### a.\n",
    "_Exercise: Display the epoch, iteration, mean training loss, and mean accuracy over each epoch (updated at each iteration), as well as the validation loss and accuracy at the end of each epoch._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 265,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load Fashion MNIST, then convert the images to float32 and divide by 255\n",
    "# (presumably to bring 8-bit pixel intensities into the 0-1 range).\n",
    "(X_train_full, y_train_full), (X_test, y_test) = keras.datasets.fashion_mnist.load_data()\n",
    "X_train_full = X_train_full.astype(np.float32) / 255.\n",
    "X_test = X_test.astype(np.float32) / 255.\n",
    "# The first 5,000 training instances become the validation set; the rest is the training set.\n",
    "X_valid, y_valid = X_train_full[:5000], y_train_full[:5000]\n",
    "X_train, y_train = X_train_full[5000:], y_train_full[5000:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 266,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 267,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.models.Sequential([\n",
    "    keras.layers.Flatten(input_shape=[28, 28]),\n",
    "    keras.layers.Dense(100, activation=\"relu\"),\n",
    "    keras.layers.Dense(10, activation=\"softmax\"),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 268,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_epochs = 5\n",
    "batch_size = 32\n",
    "n_steps = len(X_train) // batch_size  # number of iterations per epoch\n",
    "# `learning_rate` is the current argument name; `lr` is only kept as a legacy alias.\n",
    "optimizer = keras.optimizers.Nadam(learning_rate=0.01)\n",
    "loss_fn = keras.losses.sparse_categorical_crossentropy\n",
    "mean_loss = keras.metrics.Mean()  # running mean of the training loss\n",
    "metrics = [keras.metrics.SparseCategoricalAccuracy()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 269,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ecb37329043e407d8188fd3191b19ae7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='All epochs', max=5.0, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9dd1fca102e647eb9abc55d73c3b4abd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 1/5', max=1718.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d6b7cfb1402c4b46b4093f9fb7987ce0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 2/5', max=1718.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "85e0d1d7cf3a4cf7a554353ec94c57cc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 3/5', max=1718.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "15b7a028f51f4a2698a5703a1cdbceab",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 4/5', max=1718.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "97dc31821ae74a95bde14dc028db1cb8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 5/5', max=1718.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Custom training loop: one outer progress bar for the epochs, one inner bar per epoch.\n",
    "with trange(1, n_epochs + 1, desc=\"All epochs\") as epochs:\n",
    "    for epoch in epochs:\n",
    "        with trange(1, n_steps + 1, desc=\"Epoch {}/{}\".format(epoch, n_epochs)) as steps:\n",
    "            # Created before the step loop so it also exists when n_steps is 0.\n",
    "            status = OrderedDict()\n",
    "            for step in steps:\n",
    "                # NOTE(review): random_batch is defined outside this cell and is called\n",
    "                # without a batch size -- confirm it draws `batch_size` instances.\n",
    "                X_batch, y_batch = random_batch(X_train, y_train)\n",
    "                with tf.GradientTape() as tape:\n",
    "                    # training=True: run the layers in training mode for the training step.\n",
    "                    y_pred = model(X_batch, training=True)\n",
    "                    main_loss = tf.reduce_mean(loss_fn(y_batch, y_pred))\n",
    "                    # Add any extra losses registered on the model.\n",
    "                    loss = tf.add_n([main_loss] + model.losses)\n",
    "                gradients = tape.gradient(loss, model.trainable_variables)\n",
    "                optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "                # Re-apply the variables' constraints, if any, after the update.\n",
    "                for variable in model.variables:\n",
    "                    if variable.constraint is not None:\n",
    "                        variable.assign(variable.constraint(variable))\n",
    "                # Running means since the last reset, displayed in the progress bar.\n",
    "                mean_loss(loss)\n",
    "                status[\"loss\"] = mean_loss.result().numpy()\n",
    "                for metric in metrics:\n",
    "                    metric(y_batch, y_pred)\n",
    "                    status[metric.name] = metric.result().numpy()\n",
    "                steps.set_postfix(status)\n",
    "            # End of epoch: evaluate on the validation set.\n",
    "            y_pred = model(X_valid)\n",
    "            status[\"val_loss\"] = np.mean(loss_fn(y_valid, y_pred))\n",
    "            status[\"val_accuracy\"] = np.mean(keras.metrics.sparse_categorical_accuracy(\n",
    "                tf.constant(y_valid, dtype=np.float32), y_pred))\n",
    "            steps.set_postfix(status)\n",
    "        # Reset the running metrics so that the next epoch starts from scratch.\n",
    "        for metric in [mean_loss] + metrics:\n",
    "            metric.reset_states()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### b.\n",
    "_Exercise: Try using a different optimizer with a different learning rate for the upper layers and the lower layers._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 270,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 271,
   "metadata": {},
   "outputs": [],
   "source": [
    "lower_layers = keras.models.Sequential([\n",
    "    keras.layers.Flatten(input_shape=[28, 28]),\n",
    "    keras.layers.Dense(100, activation=\"relu\"),\n",
    "])\n",
    "upper_layers = keras.models.Sequential([\n",
    "    keras.layers.Dense(10, activation=\"softmax\"),\n",
    "])\n",
    "model = keras.models.Sequential([\n",
    "    lower_layers, upper_layers\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 272,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One optimizer per group of layers, each with its own learning rate.\n",
    "# `learning_rate` is the current argument name; `lr` is only kept as a legacy alias.\n",
    "lower_optimizer = keras.optimizers.SGD(learning_rate=1e-4)\n",
    "upper_optimizer = keras.optimizers.Nadam(learning_rate=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 273,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_epochs = 5\n",
    "batch_size = 32\n",
    "n_steps = len(X_train) // batch_size\n",
    "loss_fn = keras.losses.sparse_categorical_crossentropy\n",
    "mean_loss = keras.metrics.Mean()\n",
    "metrics = [keras.metrics.SparseCategoricalAccuracy()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 274,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f2c1a5ba73024caf9690088212c09d29",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='All epochs', max=5.0, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cd863de680554913bd71dcae512a671d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 1/5', max=1718.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "180ba8c1bfe649ad8b3823f2224094fa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 2/5', max=1718.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4168c450ea0c4019ab2774fd72f78410",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 3/5', max=1718.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dcc941e47a024bfaa0e9af49d59870f3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 4/5', max=1718.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3c2c7bd788734c0eb788c202ed4dac96",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 5/5', max=1718.0, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Custom training loop in which the lower and upper layers are updated by different optimizers.\n",
    "with trange(1, n_epochs + 1, desc=\"All epochs\") as epochs:\n",
    "    for epoch in epochs:\n",
    "        with trange(1, n_steps + 1, desc=\"Epoch {}/{}\".format(epoch, n_epochs)) as steps:\n",
    "            # Created before the step loop so it also exists when n_steps is 0.\n",
    "            status = OrderedDict()\n",
    "            for step in steps:\n",
    "                # NOTE(review): random_batch is defined outside this cell and is called\n",
    "                # without a batch size -- confirm it draws `batch_size` instances.\n",
    "                X_batch, y_batch = random_batch(X_train, y_train)\n",
    "                # persistent=True: tape.gradient() is called once per optimizer below.\n",
    "                with tf.GradientTape(persistent=True) as tape:\n",
    "                    # training=True: run the layers in training mode for the training step.\n",
    "                    y_pred = model(X_batch, training=True)\n",
    "                    main_loss = tf.reduce_mean(loss_fn(y_batch, y_pred))\n",
    "                    # Add any extra losses registered on the model.\n",
    "                    loss = tf.add_n([main_loss] + model.losses)\n",
    "                for layers, optimizer in ((lower_layers, lower_optimizer),\n",
    "                                          (upper_layers, upper_optimizer)):\n",
    "                    gradients = tape.gradient(loss, layers.trainable_variables)\n",
    "                    optimizer.apply_gradients(zip(gradients, layers.trainable_variables))\n",
    "                del tape  # drop the persistent tape once both gradient sets are computed\n",
    "                # Re-apply the variables' constraints, if any, after the update.\n",
    "                for variable in model.variables:\n",
    "                    if variable.constraint is not None:\n",
    "                        variable.assign(variable.constraint(variable))\n",
    "                # Running means since the last reset, displayed in the progress bar.\n",
    "                mean_loss(loss)\n",
    "                status[\"loss\"] = mean_loss.result().numpy()\n",
    "                for metric in metrics:\n",
    "                    metric(y_batch, y_pred)\n",
    "                    status[metric.name] = metric.result().numpy()\n",
    "                steps.set_postfix(status)\n",
    "            # End of epoch: evaluate on the validation set.\n",
    "            y_pred = model(X_valid)\n",
    "            status[\"val_loss\"] = np.mean(loss_fn(y_valid, y_pred))\n",
    "            status[\"val_accuracy\"] = np.mean(keras.metrics.sparse_categorical_accuracy(\n",
    "                tf.constant(y_valid, dtype=np.float32), y_pred))\n",
    "            steps.set_postfix(status)\n",
    "        # Reset the running metrics so that the next epoch starts from scratch.\n",
    "        for metric in [mean_loss] + metrics:\n",
    "            metric.reset_states()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
